Source file
src/database/sql/fakedb_test.go
1
2
3
4
5 package sql
6
7 import (
8 "context"
9 "database/sql/driver"
10 "errors"
11 "fmt"
12 "io"
13 "reflect"
14 "slices"
15 "strconv"
16 "strings"
17 "sync"
18 "sync/atomic"
19 "testing"
20 "time"
21 )
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47 type fakeDriver struct {
48 mu sync.Mutex
49 openCount int
50 closeCount int
51 waitCh chan struct{}
52 waitingCh chan struct{}
53 dbs map[string]*fakeDB
54 }
55
// fakeConnector is the driver.Connector returned by
// fakeDriverCtx.OpenConnector. Connect opens connections to the database
// called name through the package-level fdriver.
type fakeConnector struct {
	// name is passed verbatim to fdriver.Open as the DSN.
	name string

	// waiter, if set, is copied onto every fakeConn produced by Connect.
	waiter func(context.Context)
	// closed records that Close has been called; a second Close fails.
	closed bool
}
62
63 func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
64 conn, err := fdriver.Open(c.name)
65 conn.(*fakeConn).waiter = c.waiter
66 return conn, err
67 }
68
// Driver returns the package-level fake driver, fdriver.
func (c *fakeConnector) Driver() driver.Driver {
	return fdriver
}
72
73 func (c *fakeConnector) Close() error {
74 if c.closed {
75 return errors.New("fakedb: connector is closed")
76 }
77 c.closed = true
78 return nil
79 }
80
81 type fakeDriverCtx struct {
82 fakeDriver
83 }
84
85 var _ driver.DriverContext = &fakeDriverCtx{}
86
87 func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
88 return &fakeConnector{name: name}, nil
89 }
90
// fakeDB is one named in-memory database owned by a fakeDriver.
type fakeDB struct {
	// name is the key under which fakeDriver.getDB stores this database.
	name string

	// useRawBytes is switched on by the USE_RAWBYTES command. When set,
	// rowsCursor.Next hands out stored []byte values directly instead of
	// per-cursor clones.
	useRawBytes atomic.Bool

	// mu is held by wipe, createTable and columnType, and by callers
	// around db.table lookups.
	mu sync.Mutex
	// tables maps table name to table; nil until the first createTable
	// and again after wipe.
	tables map[string]*table
	// badConn is toggled by fakeConn.isBad for connections flagged bad.
	badConn bool
	// allowAny makes checkSubsetTypes accept argument values of any type.
	allowAny bool
}
101
// fakeError is an error with a free-form message that wraps another
// error, so callers can reach the wrapped value through errors.Is and
// errors.As.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns only the message; the wrapped error's text is not included.
func (e fakeError) Error() string { return e.Message }

// Unwrap exposes the wrapped error (nil when none was set).
func (e fakeError) Unwrap() error { return e.Wrapped }
114
// table is a single in-memory table: parallel column name/type slices
// plus the stored rows.
type table struct {
	// mu is held by execInsert and QueryContext while reading or
	// appending rows.
	mu sync.Mutex

	// colname[i] and coltype[i] describe column i.
	colname, coltype []string

	rows []*row
}

// columnIndex reports the position of the named column, or -1 when the
// table has no such column.
func (t *table) columnIndex(name string) int {
	return slices.Index(t.colname, name)
}

// row holds one stored record. In table.rows, cols[i] is the value of
// column i.
type row struct {
	cols []any
}
129
// memToucher is implemented by fake objects that can write to some of
// their own memory on demand.
type memToucher interface {
	// touchMem writes to memory owned by the implementer; the
	// implementations in this file increment a counter field.
	// NOTE(review): presumably this exists so the race detector can
	// catch unsynchronized use — confirm.
	touchMem()
}
134
// fakeConn is one connection handed out by fakeDriver.Open.
type fakeConn struct {
	// db is the database this connection belongs to; Close sets it to nil.
	db *fakeDB

	// currTx is the open transaction, or nil when none is in progress.
	currTx *fakeTx

	// line is incremented by touchMem.
	// NOTE(review): presumably a dummy write for the race detector — confirm.
	line int64

	// Statistics. stmtsMade and stmtsClosed are updated through incrStat,
	// which holds mu; numPrepare is incremented directly by PrepareContext.
	mu sync.Mutex
	stmtsMade int
	stmtsClosed int
	numPrepare int

	// Bad-connection simulation; see isBad.
	bad bool
	stickyBad bool

	// skipDirtySession disables the dirty-session check in isDirtyAndMark.
	skipDirtySession bool

	// dirtySession is set by isDirtyAndMark once a statement has run and
	// cleared by ResetSession.
	dirtySession bool

	// waiter, if non-nil, is called by PrepareContext once per statement in
	// the query.
	waiter func(context.Context)
}
164
// touchMem implements memToucher by incrementing c.line. It takes no lock.
func (c *fakeConn) touchMem() {
	c.line++
}
168
169 func (c *fakeConn) incrStat(v *int) {
170 c.mu.Lock()
171 *v++
172 c.mu.Unlock()
173 }
174
// fakeTx is the transaction value returned by fakeConn.Begin.
type fakeTx struct {
	// c is the connection the transaction was started on.
	c *fakeConn
}
178
// boundCol describes one filter column of a prepared SELECT.
type boundCol struct {
	// Column is the name of the table column being compared.
	Column string
	// Placeholder is the raw placeholder text: "?" for a positional
	// argument, or "?name" for a named one.
	Placeholder string
	// Ordinal is the 1-based position of the placeholder in the statement.
	Ordinal int
}
184
// fakeStmt is a prepared statement produced by fakeConn.PrepareContext.
type fakeStmt struct {
	// memToucher is set to the owning connection by PrepareContext.
	memToucher
	// c is the connection the statement was prepared on.
	c *fakeConn
	// q is the text of this single statement as given to PrepareContext.
	q string

	// cmd is the command word (e.g. "SELECT", "INSERT").
	cmd string
	// table is the table the command operates on.
	table string
	// panic names a fakeStmt method that should panic when called
	// (set by a "PANIC|<method>|" prefix).
	panic string
	// wait is the delay requested by a "WAIT|<duration>|" prefix.
	wait time.Duration

	// next links to the following statement when the query contained
	// several statements separated by ";".
	next *fakeStmt

	// closed is set by Close.
	closed bool

	// colName: column names for CREATE and INSERT, selected columns for SELECT.
	colName []string
	// colType: column types for CREATE.
	colType []string
	// colValue: for INSERT, a pre-bound value or a placeholder string per column.
	colValue []any
	// placeholders is the number of placeholder parameters.
	placeholders int

	// whereCol lists the filter columns of a SELECT.
	whereCol []boundCol

	// placeholderConverter holds one converter per INSERT placeholder.
	placeholderConverter []driver.ValueConverter
}
208
// fdriver is the shared fake driver instance used throughout this file.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
214
// Dummy wraps a driver.Driver by embedding. TestDrivers registers a zero
// Dummy (nil embedded driver) under the name "invalid".
type Dummy struct {
	driver.Driver
}
218
219 func TestDrivers(t *testing.T) {
220 unregisterAllDrivers()
221 Register("test", fdriver)
222 Register("invalid", Dummy{})
223 all := Drivers()
224 if len(all) < 2 || !slices.IsSorted(all) || !slices.Contains(all, "test") || !slices.Contains(all, "invalid") {
225 t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
226 }
227 }
228
229
// hookOpenErr holds an optional callback that fakeDriver.Open consults
// before doing any work; a non-nil error from it aborts the Open.
// The embedded mutex guards fn.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the Open hook; passing nil removes it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
240
241
242
243
244
245
246
247 func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
248 hookOpenErr.Lock()
249 fn := hookOpenErr.fn
250 hookOpenErr.Unlock()
251 if fn != nil {
252 if err := fn(); err != nil {
253 return nil, err
254 }
255 }
256 parts := strings.Split(dsn, ";")
257 if len(parts) < 1 {
258 return nil, errors.New("fakedb: no database name")
259 }
260 name := parts[0]
261
262 db := d.getDB(name)
263
264 d.mu.Lock()
265 d.openCount++
266 d.mu.Unlock()
267 conn := &fakeConn{db: db}
268
269 if len(parts) >= 2 && parts[1] == "badConn" {
270 conn.bad = true
271 }
272 if d.waitCh != nil {
273 d.waitingCh <- struct{}{}
274 <-d.waitCh
275 d.waitCh = nil
276 d.waitingCh = nil
277 }
278 return conn, nil
279 }
280
281 func (d *fakeDriver) getDB(name string) *fakeDB {
282 d.mu.Lock()
283 defer d.mu.Unlock()
284 if d.dbs == nil {
285 d.dbs = make(map[string]*fakeDB)
286 }
287 db, ok := d.dbs[name]
288 if !ok {
289 db = &fakeDB{name: name}
290 d.dbs[name] = db
291 }
292 return db
293 }
294
295 func (db *fakeDB) wipe() {
296 db.mu.Lock()
297 defer db.mu.Unlock()
298 db.tables = nil
299 }
300
301 func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
302 db.mu.Lock()
303 defer db.mu.Unlock()
304 if db.tables == nil {
305 db.tables = make(map[string]*table)
306 }
307 if _, exist := db.tables[name]; exist {
308 return fmt.Errorf("fakedb: table %q already exists", name)
309 }
310 if len(columnNames) != len(columnTypes) {
311 return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
312 name, len(columnNames), len(columnTypes))
313 }
314 db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
315 return nil
316 }
317
318
319 func (db *fakeDB) table(table string) (*table, bool) {
320 if db.tables == nil {
321 return nil, false
322 }
323 t, ok := db.tables[table]
324 return t, ok
325 }
326
327 func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
328 db.mu.Lock()
329 defer db.mu.Unlock()
330 t, ok := db.table(table)
331 if !ok {
332 return
333 }
334 if i := slices.Index(t.colname, column); i != -1 {
335 return t.coltype[i], true
336 }
337 return "", false
338 }
339
340 func (c *fakeConn) isBad() bool {
341 if c.stickyBad {
342 return true
343 } else if c.bad {
344 if c.db == nil {
345 return false
346 }
347
348 c.db.badConn = !c.db.badConn
349 return c.db.badConn
350 } else {
351 return false
352 }
353 }
354
355 func (c *fakeConn) isDirtyAndMark() bool {
356 if c.skipDirtySession {
357 return false
358 }
359 if c.currTx != nil {
360 c.dirtySession = true
361 return false
362 }
363 if c.dirtySession {
364 return true
365 }
366 c.dirtySession = true
367 return false
368 }
369
370 func (c *fakeConn) Begin() (driver.Tx, error) {
371 if c.isBad() {
372 return nil, fakeError{Wrapped: driver.ErrBadConn}
373 }
374 if c.currTx != nil {
375 return nil, errors.New("fakedb: already in a transaction")
376 }
377 c.touchMem()
378 c.currTx = &fakeTx{c: c}
379 return c.currTx, nil
380 }
381
382 var hookPostCloseConn struct {
383 sync.Mutex
384 fn func(*fakeConn, error)
385 }
386
387 func setHookpostCloseConn(fn func(*fakeConn, error)) {
388 hookPostCloseConn.Lock()
389 defer hookPostCloseConn.Unlock()
390 hookPostCloseConn.fn = fn
391 }
392
// testStrictClose, when non-nil, receives an Errorf call from
// fakeConn.Close whenever that Close returns an error.
var testStrictClose *testing.T

// setStrictFakeConnClose sets the test that is notified when a
// fakeConn.Close fails. Passing nil disables the check.
func setStrictFakeConnClose(t *testing.T) {
	testStrictClose = t
}
400
401 func (c *fakeConn) ResetSession(ctx context.Context) error {
402 c.dirtySession = false
403 c.currTx = nil
404 if c.isBad() {
405 return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
406 }
407 return nil
408 }
409
// Compile-time check that *fakeConn implements driver.Validator.
var _ driver.Validator = (*fakeConn)(nil)

// IsValid reports the negation of isBad. Note that isBad may toggle
// db.badConn as a side effect.
func (c *fakeConn) IsValid() bool {
	return !c.isBad()
}
415
416 func (c *fakeConn) Close() (err error) {
417 drv := fdriver.(*fakeDriver)
418 defer func() {
419 if err != nil && testStrictClose != nil {
420 testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
421 }
422 hookPostCloseConn.Lock()
423 fn := hookPostCloseConn.fn
424 hookPostCloseConn.Unlock()
425 if fn != nil {
426 fn(c, err)
427 }
428 if err == nil {
429 drv.mu.Lock()
430 drv.closeCount++
431 drv.mu.Unlock()
432 }
433 }()
434 c.touchMem()
435 if c.currTx != nil {
436 return errors.New("fakedb: can't close fakeConn; in a Transaction")
437 }
438 if c.db == nil {
439 return errors.New("fakedb: can't close fakeConn; already closed")
440 }
441 if c.stmtsMade > c.stmtsClosed {
442 return errors.New("fakedb: can't close; dangling statement(s)")
443 }
444 c.db = nil
445 return nil
446 }
447
// checkSubsetTypes verifies that every argument value has one of the types
// int64, float64, bool, nil, []byte, string or time.Time. The first
// offending argument is reported with its ordinal, value and type. When
// allowAny is true nothing is rejected.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for i := range args {
		arg := &args[i]
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
460
// Exec always panics: fakeConn also provides ExecContext, and this method
// exists to fail loudly if it is called instead.
func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
	// Ensure that ExecContext is called if available.
	panic("ExecContext was not called.")
}
465
466 func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
467
468
469
470
471 err := checkSubsetTypes(c.db.allowAny, args)
472 if err != nil {
473 return nil, err
474 }
475 return nil, driver.ErrSkip
476 }
477
// Query always panics: fakeConn also provides QueryContext, and this
// method exists to fail loudly if it is called instead.
func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
	// Ensure that QueryContext is called if available.
	panic("QueryContext was not called.")
}
482
483 func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
484
485
486
487
488 err := checkSubsetTypes(c.db.allowAny, args)
489 if err != nil {
490 return nil, err
491 }
492 return nil, driver.ErrSkip
493 }
494
// errf formats msg with args and returns it as an error prefixed with
// "fakedb: ".
func errf(msg string, args ...any) error {
	detail := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + detail)
}
498
499
500
501
502 func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
503 if len(parts) != 3 {
504 stmt.Close()
505 return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
506 }
507 stmt.table = parts[0]
508
509 stmt.colName = strings.Split(parts[1], ",")
510 for n, colspec := range strings.Split(parts[2], ",") {
511 if colspec == "" {
512 continue
513 }
514 nameVal := strings.Split(colspec, "=")
515 if len(nameVal) != 2 {
516 stmt.Close()
517 return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
518 }
519 column, value := nameVal[0], nameVal[1]
520 _, ok := c.db.columnType(stmt.table, column)
521 if !ok {
522 stmt.Close()
523 return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
524 }
525 if !strings.HasPrefix(value, "?") {
526 stmt.Close()
527 return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
528 stmt.table, column)
529 }
530 stmt.placeholders++
531 stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
532 }
533 return stmt, nil
534 }
535
536
537 func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
538 if len(parts) != 2 {
539 stmt.Close()
540 return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
541 }
542 stmt.table = parts[0]
543 for n, colspec := range strings.Split(parts[1], ",") {
544 nameType := strings.Split(colspec, "=")
545 if len(nameType) != 2 {
546 stmt.Close()
547 return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
548 }
549 stmt.colName = append(stmt.colName, nameType[0])
550 stmt.colType = append(stmt.colType, nameType[1])
551 }
552 return stmt, nil
553 }
554
555
556 func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
557 if len(parts) != 2 {
558 stmt.Close()
559 return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
560 }
561 stmt.table = parts[0]
562 for n, colspec := range strings.Split(parts[1], ",") {
563 nameVal := strings.Split(colspec, "=")
564 if len(nameVal) != 2 {
565 stmt.Close()
566 return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
567 }
568 column, value := nameVal[0], nameVal[1]
569 ctype, ok := c.db.columnType(stmt.table, column)
570 if !ok {
571 stmt.Close()
572 return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
573 }
574 stmt.colName = append(stmt.colName, column)
575
576 if !strings.HasPrefix(value, "?") {
577 var subsetVal any
578
579 switch ctype {
580 case "string":
581 subsetVal = []byte(value)
582 case "blob":
583 subsetVal = []byte(value)
584 case "int32":
585 i, err := strconv.Atoi(value)
586 if err != nil {
587 stmt.Close()
588 return nil, errf("invalid conversion to int32 from %q", value)
589 }
590 subsetVal = int64(i)
591 case "table":
592 c.skipDirtySession = true
593 vparts := strings.Split(value, "!")
594
595 substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
596 if err != nil {
597 return nil, err
598 }
599 cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
600 substmt.Close()
601 if err != nil {
602 return nil, err
603 }
604 subsetVal = cursor
605 default:
606 stmt.Close()
607 return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
608 }
609 stmt.colValue = append(stmt.colValue, subsetVal)
610 } else {
611 stmt.placeholders++
612 stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
613 stmt.colValue = append(stmt.colValue, value)
614 }
615 }
616 return stmt, nil
617 }
618
619
// hookPrepareBadConn, when non-nil and returning true, makes
// PrepareContext fail with an error wrapping driver.ErrBadConn.
var hookPrepareBadConn func() bool

// Prepare always panics; statements must be prepared with PrepareContext.
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
	panic("use PrepareContext")
}
625
// PrepareContext parses query in the fake query language and returns the
// resulting statement.
//
// query may hold several statements separated by ";"; they are parsed in
// order and chained through fakeStmt.next, and the first one is returned.
// Within a statement the sections are separated by "|". A statement with
// at least three sections may start with "PANIC|<method>|" or
// "WAIT|<duration>|"; the next section is the command word, one of WIPE,
// USE_RAWBYTES, SELECT, CREATE, INSERT or NOSERT.
//
// It panics if the connection is already closed (c.db == nil) and returns
// an error wrapping driver.ErrBadConn when the connection is stickyBad or
// hookPrepareBadConn reports true.
func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
	c.numPrepare++
	if c.db == nil {
		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
	}

	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
	}

	c.touchMem()
	var firstStmt, prev *fakeStmt
	for _, query := range strings.Split(query, ";") {
		parts := strings.Split(query, "|")
		if len(parts) < 1 {
			return nil, errf("empty query")
		}
		stmt := &fakeStmt{q: query, c: c, memToucher: c}
		if firstStmt == nil {
			firstStmt = stmt
		}
		// Strip an optional PANIC or WAIT prefix and record it on stmt.
		if len(parts) >= 3 {
			switch parts[0] {
			case "PANIC":
				stmt.panic = parts[1]
				parts = parts[2:]
			case "WAIT":
				wait, err := time.ParseDuration(parts[1])
				if err != nil {
					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
				}
				parts = parts[2:]
				stmt.wait = wait
			}
		}
		cmd := parts[0]
		stmt.cmd = cmd
		parts = parts[1:]

		// Give the test-supplied waiter a chance to block, then honor a
		// context that was cancelled in the meantime.
		if c.waiter != nil {
			c.waiter(ctx)
			if err := ctx.Err(); err != nil {
				return nil, err
			}
		}

		// A WAIT prefix delays the prepare itself, unless ctx ends first.
		if stmt.wait > 0 {
			wait := time.NewTimer(stmt.wait)
			select {
			case <-wait.C:
			case <-ctx.Done():
				wait.Stop()
				return nil, ctx.Err()
			}
		}

		// From here on the statement counts as made; the error paths in
		// the prepare helpers close it again.
		c.incrStat(&c.stmtsMade)
		var err error
		switch cmd {
		case "WIPE":
			// Nothing to prepare.
		case "USE_RAWBYTES":
			c.db.useRawBytes.Store(true)
		case "SELECT":
			stmt, err = c.prepareSelect(stmt, parts)
		case "CREATE":
			stmt, err = c.prepareCreate(stmt, parts)
		case "INSERT":
			stmt, err = c.prepareInsert(ctx, stmt, parts)
		case "NOSERT":
			// Prepared exactly like an INSERT; ExecContext later skips
			// storing the row for this command.
			stmt, err = c.prepareInsert(ctx, stmt, parts)
		default:
			stmt.Close()
			return nil, errf("unsupported command type %q", cmd)
		}
		if err != nil {
			return nil, err
		}
		if prev != nil {
			prev.next = stmt
		}
		prev = stmt
	}
	return firstStmt, nil
}
713
714 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
715 if s.panic == "ColumnConverter" {
716 panic(s.panic)
717 }
718 if len(s.placeholderConverter) == 0 {
719 return driver.DefaultParameterConverter
720 }
721 return s.placeholderConverter[idx]
722 }
723
724 func (s *fakeStmt) Close() error {
725 if s.panic == "Close" {
726 panic(s.panic)
727 }
728 if s.c == nil {
729 panic("nil conn in fakeStmt.Close")
730 }
731 if s.c.db == nil {
732 panic("in fakeStmt.Close, conn's db is nil (already closed)")
733 }
734 s.touchMem()
735 if !s.closed {
736 s.c.incrStat(&s.c.stmtsClosed)
737 s.closed = true
738 }
739 if s.next != nil {
740 s.next.Close()
741 }
742 return nil
743 }
744
// errClosed is returned by ExecContext and QueryContext on a statement
// that has already been closed.
var errClosed = errors.New("fakedb: statement has been closed")

// hookExecBadConn, when non-nil and returning true, makes
// fakeStmt.ExecContext fail with an error wrapping driver.ErrBadConn.
var hookExecBadConn func() bool

// Exec always panics; ExecContext must be used instead.
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
	panic("Using ExecContext")
}

// errFakeConnSessionDirty is returned by ExecContext and QueryContext when
// isDirtyAndMark reports that the session is already dirty.
var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
755
756 func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
757 if s.panic == "Exec" {
758 panic(s.panic)
759 }
760 if s.closed {
761 return nil, errClosed
762 }
763
764 if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
765 return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
766 }
767 if s.c.isDirtyAndMark() {
768 return nil, errFakeConnSessionDirty
769 }
770
771 err := checkSubsetTypes(s.c.db.allowAny, args)
772 if err != nil {
773 return nil, err
774 }
775 s.touchMem()
776
777 if s.wait > 0 {
778 time.Sleep(s.wait)
779 }
780
781 select {
782 default:
783 case <-ctx.Done():
784 return nil, ctx.Err()
785 }
786
787 db := s.c.db
788 switch s.cmd {
789 case "WIPE":
790 db.wipe()
791 return driver.ResultNoRows, nil
792 case "USE_RAWBYTES":
793 s.c.db.useRawBytes.Store(true)
794 return driver.ResultNoRows, nil
795 case "CREATE":
796 if err := db.createTable(s.table, s.colName, s.colType); err != nil {
797 return nil, err
798 }
799 return driver.ResultNoRows, nil
800 case "INSERT":
801 return s.execInsert(args, true)
802 case "NOSERT":
803
804
805 return s.execInsert(args, false)
806 }
807 return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
808 }
809
// valueFromPlaceholderName returns the value of the first argument whose
// Name equals name, or nil when no argument carries that name.
func valueFromPlaceholderName(args []driver.NamedValue, name string) driver.Value {
	for _, arg := range args {
		if arg.Name != name {
			continue
		}
		return arg.Value
	}
	return nil
}
818
819
820
821
822 func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
823 db := s.c.db
824 if len(args) != s.placeholders {
825 panic("error in pkg db; should only get here if size is correct")
826 }
827 db.mu.Lock()
828 t, ok := db.table(s.table)
829 db.mu.Unlock()
830 if !ok {
831 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
832 }
833
834 t.mu.Lock()
835 defer t.mu.Unlock()
836
837 var cols []any
838 if doInsert {
839 cols = make([]any, len(t.colname))
840 }
841 argPos := 0
842 for n, colname := range s.colName {
843 colidx := t.columnIndex(colname)
844 if colidx == -1 {
845 return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
846 }
847 var val any
848 if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
849 if strvalue == "?" {
850 val = args[argPos].Value
851 } else {
852
853 if v := valueFromPlaceholderName(args, strvalue[1:]); v != nil {
854 val = v
855 }
856 }
857 argPos++
858 } else {
859 val = s.colValue[n]
860 }
861 if doInsert {
862 cols[colidx] = val
863 }
864 }
865
866 if doInsert {
867 t.rows = append(t.rows, &row{cols: cols})
868 }
869 return driver.RowsAffected(1), nil
870 }
871
872
// hookQueryBadConn, when non-nil and returning true, makes
// fakeStmt.QueryContext fail with an error wrapping driver.ErrBadConn.
var hookQueryBadConn func() bool

// Query always panics; QueryContext must be used instead.
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
	panic("Use QueryContext")
}
878
879 func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
880 if s.panic == "Query" {
881 panic(s.panic)
882 }
883 if s.closed {
884 return nil, errClosed
885 }
886
887 if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
888 return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
889 }
890 if s.c.isDirtyAndMark() {
891 return nil, errFakeConnSessionDirty
892 }
893
894 err := checkSubsetTypes(s.c.db.allowAny, args)
895 if err != nil {
896 return nil, err
897 }
898
899 s.touchMem()
900 db := s.c.db
901 if len(args) != s.placeholders {
902 panic("error in pkg db; should only get here if size is correct")
903 }
904
905 setMRows := make([][]*row, 0, 1)
906 setColumns := make([][]string, 0, 1)
907 setColType := make([][]string, 0, 1)
908
909 for {
910 db.mu.Lock()
911 t, ok := db.table(s.table)
912 db.mu.Unlock()
913 if !ok {
914 return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
915 }
916
917 if s.table == "magicquery" {
918 if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
919 if args[0].Value == "sleep" {
920 time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
921 }
922 }
923 }
924 if s.table == "tx_status" && s.colName[0] == "tx_status" {
925 txStatus := "autocommit"
926 if s.c.currTx != nil {
927 txStatus = "transaction"
928 }
929 cursor := &rowsCursor{
930 db: s.c.db,
931 parentMem: s.c,
932 posRow: -1,
933 rows: [][]*row{
934 {
935 {
936 cols: []any{
937 txStatus,
938 },
939 },
940 },
941 },
942 cols: [][]string{
943 {
944 "tx_status",
945 },
946 },
947 colType: [][]string{
948 {
949 "string",
950 },
951 },
952 errPos: -1,
953 }
954 return cursor, nil
955 }
956
957 t.mu.Lock()
958
959 colIdx := make(map[string]int)
960 for _, name := range s.colName {
961 idx := t.columnIndex(name)
962 if idx == -1 {
963 t.mu.Unlock()
964 return nil, fmt.Errorf("fakedb: unknown column name %q", name)
965 }
966 colIdx[name] = idx
967 }
968
969 mrows := []*row{}
970 rows:
971 for _, trow := range t.rows {
972
973
974
975 for _, wcol := range s.whereCol {
976 idx := t.columnIndex(wcol.Column)
977 if idx == -1 {
978 t.mu.Unlock()
979 return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
980 }
981 tcol := trow.cols[idx]
982 if bs, ok := tcol.([]byte); ok {
983
984 tcol = string(bs)
985 }
986 var argValue any
987 if wcol.Placeholder == "?" {
988 argValue = args[wcol.Ordinal-1].Value
989 } else {
990 if v := valueFromPlaceholderName(args, wcol.Placeholder[1:]); v != nil {
991 argValue = v
992 }
993 }
994 if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
995 continue rows
996 }
997 }
998 mrow := &row{cols: make([]any, len(s.colName))}
999 for seli, name := range s.colName {
1000 mrow.cols[seli] = trow.cols[colIdx[name]]
1001 }
1002 mrows = append(mrows, mrow)
1003 }
1004
1005 var colType []string
1006 for _, column := range s.colName {
1007 colType = append(colType, t.coltype[t.columnIndex(column)])
1008 }
1009
1010 t.mu.Unlock()
1011
1012 setMRows = append(setMRows, mrows)
1013 setColumns = append(setColumns, s.colName)
1014 setColType = append(setColType, colType)
1015
1016 if s.next == nil {
1017 break
1018 }
1019 s = s.next
1020 }
1021
1022 cursor := &rowsCursor{
1023 db: s.c.db,
1024 parentMem: s.c,
1025 posRow: -1,
1026 rows: setMRows,
1027 cols: setColumns,
1028 colType: setColType,
1029 errPos: -1,
1030 }
1031 return cursor, nil
1032 }
1033
// NumInput returns the number of placeholder parameters of the statement.
// It panics when the statement was prepared with PANIC|NumInput.
func (s *fakeStmt) NumInput() int {
	if s.panic == "NumInput" {
		panic(s.panic)
	}
	return s.placeholders
}
1040
1041
1042 var hookCommitBadConn func() bool
1043
1044 func (tx *fakeTx) Commit() error {
1045 tx.c.currTx = nil
1046 if hookCommitBadConn != nil && hookCommitBadConn() {
1047 return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1048 }
1049 tx.c.touchMem()
1050 return nil
1051 }
1052
1053
1054 var hookRollbackBadConn func() bool
1055
1056 func (tx *fakeTx) Rollback() error {
1057 tx.c.currTx = nil
1058 if hookRollbackBadConn != nil && hookRollbackBadConn() {
1059 return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
1060 }
1061 tx.c.touchMem()
1062 return nil
1063 }
1064
// rowsCursor is the driver.Rows implementation returned by
// fakeStmt.QueryContext. It can hold several result sets; posSet selects
// the current one.
type rowsCursor struct {
	// db is consulted by Next for the useRawBytes setting.
	db *fakeDB
	// parentMem is touched whenever the cursor itself is touched.
	parentMem memToucher
	// cols[set] and colType[set] are the column names and types of a result set.
	cols [][]string
	colType [][]string
	// posSet is the index of the current result set.
	posSet int
	// posRow is the index of the current row; -1 before the first Next.
	posRow int
	// rows[set] holds the rows of a result set.
	rows [][]*row
	// closed is set by Close; Next fails afterwards.
	closed bool

	// When Next advances to row index errPos it returns err instead of data.
	errPos int
	err error

	// bytesClone caches the copies of []byte column values handed out by
	// Next, keyed by the address of the original slice's first byte.
	bytesClone map[*byte][]byte

	// line is incremented by touchMem.
	// NOTE(review): presumably a dummy write for the race detector — confirm.
	line int64

	// closeErr is the value returned by Close.
	closeErr error
}
1093
// touchMem implements memToucher: it touches the parent first and then
// increments the cursor's own line counter.
func (rc *rowsCursor) touchMem() {
	rc.parentMem.touchMem()
	rc.line++
}
1098
// Close marks the cursor closed and returns closeErr. The parent is
// touched twice here: once through rc.touchMem and once directly.
func (rc *rowsCursor) Close() error {
	rc.touchMem()
	rc.parentMem.touchMem()
	rc.closed = true
	return rc.closeErr
}
1105
// Columns returns the column names of the current result set.
func (rc *rowsCursor) Columns() []string {
	return rc.cols[rc.posSet]
}
1109
// ColumnTypeScanType returns the Go type for column index of the current
// result set, as mapped by colTypeToReflectType (which panics on an
// unknown column type).
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
	return colTypeToReflectType(rc.colType[rc.posSet][index])
}
1113
1114 var rowsCursorNextHook func(dest []driver.Value) error
1115
1116 func (rc *rowsCursor) Next(dest []driver.Value) error {
1117 if rowsCursorNextHook != nil {
1118 return rowsCursorNextHook(dest)
1119 }
1120
1121 if rc.closed {
1122 return errors.New("fakedb: cursor is closed")
1123 }
1124 rc.touchMem()
1125 rc.posRow++
1126 if rc.posRow == rc.errPos {
1127 return rc.err
1128 }
1129 if rc.posRow >= len(rc.rows[rc.posSet]) {
1130 return io.EOF
1131 }
1132 for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
1133
1134
1135
1136
1137
1138
1139 dest[i] = v
1140
1141 if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
1142 if rc.bytesClone == nil {
1143 rc.bytesClone = make(map[*byte][]byte)
1144 }
1145 clone, ok := rc.bytesClone[&bs[0]]
1146 if !ok {
1147 clone = make([]byte, len(bs))
1148 copy(clone, bs)
1149 rc.bytesClone[&bs[0]] = clone
1150 }
1151 dest[i] = clone
1152 }
1153 }
1154 return nil
1155 }
1156
// HasNextResultSet reports whether another result set follows the
// current one.
func (rc *rowsCursor) HasNextResultSet() bool {
	rc.touchMem()
	return rc.posSet < len(rc.rows)-1
}
1161
1162 func (rc *rowsCursor) NextResultSet() error {
1163 rc.touchMem()
1164 if rc.HasNextResultSet() {
1165 rc.posSet++
1166 rc.posRow = -1
1167 return nil
1168 }
1169 return io.EOF
1170 }
1171
1172
1173
1174
1175
1176
1177
// fakeDriverString is the driver.ValueConverter used for "string" and
// "nullstring" columns.
type fakeDriverString struct{}

// ConvertValue passes string and []byte values through untouched,
// dereferences a *string (a nil pointer becomes a nil value) and renders
// everything else with the %v verb. It never returns an error.
func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
	switch c := v.(type) {
	case string:
		return c, nil
	case []byte:
		return c, nil
	case *string:
		if c != nil {
			return *c, nil
		}
		return nil, nil
	default:
		return fmt.Sprintf("%v", v), nil
	}
}
1192
// anyTypeConverter is the driver.ValueConverter used for "any" columns.
type anyTypeConverter struct{}

// ConvertValue returns v unchanged and never fails.
func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
	return v, nil
}
1198
1199 func converterForType(typ string) driver.ValueConverter {
1200 switch typ {
1201 case "bool":
1202 return driver.Bool
1203 case "nullbool":
1204 return driver.Null{Converter: driver.Bool}
1205 case "byte", "int16":
1206 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1207 case "int32":
1208 return driver.Int32
1209 case "nullbyte", "nullint32", "nullint16":
1210 return driver.Null{Converter: driver.DefaultParameterConverter}
1211 case "string":
1212 return driver.NotNull{Converter: fakeDriverString{}}
1213 case "nullstring":
1214 return driver.Null{Converter: fakeDriverString{}}
1215 case "int64":
1216
1217 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1218 case "nullint64":
1219
1220 return driver.Null{Converter: driver.DefaultParameterConverter}
1221 case "float64":
1222
1223 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1224 case "nullfloat64":
1225
1226 return driver.Null{Converter: driver.DefaultParameterConverter}
1227 case "datetime":
1228 return driver.NotNull{Converter: driver.DefaultParameterConverter}
1229 case "nulldatetime":
1230 return driver.Null{Converter: driver.DefaultParameterConverter}
1231 case "any":
1232 return anyTypeConverter{}
1233 }
1234 panic("invalid fakedb column type of " + typ)
1235 }
1236
1237 func colTypeToReflectType(typ string) reflect.Type {
1238 switch typ {
1239 case "bool":
1240 return reflect.TypeFor[bool]()
1241 case "nullbool":
1242 return reflect.TypeFor[NullBool]()
1243 case "int16":
1244 return reflect.TypeFor[int16]()
1245 case "nullint16":
1246 return reflect.TypeFor[NullInt16]()
1247 case "int32":
1248 return reflect.TypeFor[int32]()
1249 case "nullint32":
1250 return reflect.TypeFor[NullInt32]()
1251 case "string":
1252 return reflect.TypeFor[string]()
1253 case "nullstring":
1254 return reflect.TypeFor[NullString]()
1255 case "int64":
1256 return reflect.TypeFor[int64]()
1257 case "nullint64":
1258 return reflect.TypeFor[NullInt64]()
1259 case "float64":
1260 return reflect.TypeFor[float64]()
1261 case "nullfloat64":
1262 return reflect.TypeFor[NullFloat64]()
1263 case "datetime":
1264 return reflect.TypeFor[time.Time]()
1265 case "any":
1266 return reflect.TypeFor[any]()
1267 }
1268 panic("invalid fakedb column type of " + typ)
1269 }
1270
View as plain text