1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sql
17
18 import (
19 "context"
20 "database/sql/driver"
21 "errors"
22 "fmt"
23 "io"
24 "maps"
25 "math/rand/v2"
26 "reflect"
27 "runtime"
28 "slices"
29 "strconv"
30 "sync"
31 "sync/atomic"
32 "time"
33 _ "unsafe"
34 )
35
36 var driversMu sync.RWMutex
37
38
39
40
41
42
43
44
45
46
47
48 var drivers = make(map[string]driver.Driver)
49
50
51 var nowFunc = time.Now
52
53
54
55
56 func Register(name string, driver driver.Driver) {
57 driversMu.Lock()
58 defer driversMu.Unlock()
59 if driver == nil {
60 panic("sql: Register driver is nil")
61 }
62 if _, dup := drivers[name]; dup {
63 panic("sql: Register called twice for driver " + name)
64 }
65 drivers[name] = driver
66 }
67
68 func unregisterAllDrivers() {
69 driversMu.Lock()
70 defer driversMu.Unlock()
71
72 drivers = make(map[string]driver.Driver)
73 }
74
75
76 func Drivers() []string {
77 driversMu.RLock()
78 defer driversMu.RUnlock()
79 return slices.Sorted(maps.Keys(drivers))
80 }
81
82
83
84
85
86
87
// NamedArg pairs a parameter name with its value.
// NOTE(review): the code that binds NamedArg values to statement placeholders
// is not in this chunk; upstream documents it as usable in Query/Exec args.
type NamedArg struct {
	// _NamedFieldsRequired is an unexported field, so code outside this
	// package cannot build NamedArg with an unkeyed composite literal.
	_NamedFieldsRequired struct{}

	// Name is the name of the parameter placeholder.
	Name string

	// Value is the value bound to the parameter.
	Value any
}

// Named returns a NamedArg holding the given name and value.
func Named(name string, value any) NamedArg {
	var arg NamedArg
	arg.Name = name
	arg.Value = value
	return arg
}
124
125
// IsolationLevel is the transaction isolation level carried by TxOptions.
type IsolationLevel int

// Supported isolation levels; values are assigned by iota starting at 0.
const (
	LevelDefault IsolationLevel = iota
	LevelReadUncommitted
	LevelReadCommitted
	LevelWriteCommitted
	LevelRepeatableRead
	LevelSnapshot
	LevelSerializable
	LevelLinearizable
)

// String returns a human-readable name for the isolation level, or
// "IsolationLevel(N)" for values outside the known set.
func (i IsolationLevel) String() string {
	// Keyed by the constants themselves so the table cannot drift from them.
	names := [...]string{
		LevelDefault:         "Default",
		LevelReadUncommitted: "Read Uncommitted",
		LevelReadCommitted:   "Read Committed",
		LevelWriteCommitted:  "Write Committed",
		LevelRepeatableRead:  "Repeatable Read",
		LevelSnapshot:        "Snapshot",
		LevelSerializable:    "Serializable",
		LevelLinearizable:    "Linearizable",
	}
	if i >= 0 && int(i) < len(names) {
		return names[i]
	}
	return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}

// Compile-time check that IsolationLevel implements fmt.Stringer.
var _ fmt.Stringer = LevelDefault

// TxOptions holds the options used when beginning a transaction.
type TxOptions struct {
	// Isolation is the requested isolation level. Its zero value is
	// LevelDefault. NOTE(review): how LevelDefault is interpreted is decided
	// by beginDC/the driver, outside this chunk.
	Isolation IsolationLevel
	// ReadOnly requests a read-only transaction.
	ReadOnly bool
}

// RawBytes is a named []byte type.
// NOTE(review): its Scan/lifetime semantics are implemented elsewhere in the
// package (not visible here).
type RawBytes []byte
181
182
183
184
185
186
187
188
189
190
191
192
193
194 type NullString struct {
195 String string
196 Valid bool
197 }
198
199
200 func (ns *NullString) Scan(value any) error {
201 if value == nil {
202 ns.String, ns.Valid = "", false
203 return nil
204 }
205 err := convertAssign(&ns.String, value)
206 ns.Valid = err == nil
207 return err
208 }
209
210
211 func (ns NullString) Value() (driver.Value, error) {
212 if !ns.Valid {
213 return nil, nil
214 }
215 return ns.String, nil
216 }
217
218
219
220
221 type NullInt64 struct {
222 Int64 int64
223 Valid bool
224 }
225
226
227 func (n *NullInt64) Scan(value any) error {
228 if value == nil {
229 n.Int64, n.Valid = 0, false
230 return nil
231 }
232 err := convertAssign(&n.Int64, value)
233 n.Valid = err == nil
234 return err
235 }
236
237
238 func (n NullInt64) Value() (driver.Value, error) {
239 if !n.Valid {
240 return nil, nil
241 }
242 return n.Int64, nil
243 }
244
245
246
247
248 type NullInt32 struct {
249 Int32 int32
250 Valid bool
251 }
252
253
254 func (n *NullInt32) Scan(value any) error {
255 if value == nil {
256 n.Int32, n.Valid = 0, false
257 return nil
258 }
259 err := convertAssign(&n.Int32, value)
260 n.Valid = err == nil
261 return err
262 }
263
264
265 func (n NullInt32) Value() (driver.Value, error) {
266 if !n.Valid {
267 return nil, nil
268 }
269 return int64(n.Int32), nil
270 }
271
272
273
274
275 type NullInt16 struct {
276 Int16 int16
277 Valid bool
278 }
279
280
281 func (n *NullInt16) Scan(value any) error {
282 if value == nil {
283 n.Int16, n.Valid = 0, false
284 return nil
285 }
286 err := convertAssign(&n.Int16, value)
287 n.Valid = err == nil
288 return err
289 }
290
291
292 func (n NullInt16) Value() (driver.Value, error) {
293 if !n.Valid {
294 return nil, nil
295 }
296 return int64(n.Int16), nil
297 }
298
299
300
301
302 type NullByte struct {
303 Byte byte
304 Valid bool
305 }
306
307
308 func (n *NullByte) Scan(value any) error {
309 if value == nil {
310 n.Byte, n.Valid = 0, false
311 return nil
312 }
313 err := convertAssign(&n.Byte, value)
314 n.Valid = err == nil
315 return err
316 }
317
318
319 func (n NullByte) Value() (driver.Value, error) {
320 if !n.Valid {
321 return nil, nil
322 }
323 return int64(n.Byte), nil
324 }
325
326
327
328
329 type NullFloat64 struct {
330 Float64 float64
331 Valid bool
332 }
333
334
335 func (n *NullFloat64) Scan(value any) error {
336 if value == nil {
337 n.Float64, n.Valid = 0, false
338 return nil
339 }
340 err := convertAssign(&n.Float64, value)
341 n.Valid = err == nil
342 return err
343 }
344
345
346 func (n NullFloat64) Value() (driver.Value, error) {
347 if !n.Valid {
348 return nil, nil
349 }
350 return n.Float64, nil
351 }
352
353
354
355
356 type NullBool struct {
357 Bool bool
358 Valid bool
359 }
360
361
362 func (n *NullBool) Scan(value any) error {
363 if value == nil {
364 n.Bool, n.Valid = false, false
365 return nil
366 }
367 err := convertAssign(&n.Bool, value)
368 n.Valid = err == nil
369 return err
370 }
371
372
373 func (n NullBool) Value() (driver.Value, error) {
374 if !n.Valid {
375 return nil, nil
376 }
377 return n.Bool, nil
378 }
379
380
381
382
383 type NullTime struct {
384 Time time.Time
385 Valid bool
386 }
387
388
389 func (n *NullTime) Scan(value any) error {
390 if value == nil {
391 n.Time, n.Valid = time.Time{}, false
392 return nil
393 }
394 err := convertAssign(&n.Time, value)
395 n.Valid = err == nil
396 return err
397 }
398
399
400 func (n NullTime) Value() (driver.Value, error) {
401 if !n.Valid {
402 return nil, nil
403 }
404 return n.Time, nil
405 }
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421 type Null[T any] struct {
422 V T
423 Valid bool
424 }
425
426 func (n *Null[T]) Scan(value any) error {
427 if value == nil {
428 n.V, n.Valid = *new(T), false
429 return nil
430 }
431 err := convertAssign(&n.V, value)
432 n.Valid = err == nil
433 return err
434 }
435
436 func (n Null[T]) Value() (driver.Value, error) {
437 if !n.Valid {
438 return nil, nil
439 }
440 v := any(n.V)
441
442 if valuer, ok := v.(driver.Valuer); ok {
443 val, err := callValuerValue(valuer)
444 if err != nil {
445 return val, err
446 }
447 v = val
448 }
449
450 return driver.DefaultParameterConverter.ConvertValue(v)
451 }
452
453
// Scanner is implemented by scan destinations. The Null* types in this file
// implement it: a nil src marks the destination invalid, any other src is
// converted into the destination's value field.
type Scanner interface {
	// Scan assigns a value from a database driver to the receiver and
	// reports any conversion error.
	Scan(src any) error
}

// Out describes an output parameter.
// NOTE(review): the code consuming Out is not in this chunk; upstream
// documents Dest as a pointer receiving a stored procedure's OUTPUT value
// and In as marking an INOUT parameter.
type Out struct {
	// _NamedFieldsRequired forces keyed composite literals outside this package.
	_NamedFieldsRequired struct{}

	// Dest is the destination for the output value.
	Dest any

	// In reports whether the parameter is also an input (INOUT).
	In bool
}

// ErrNoRows is the sentinel error for "no rows in result set".
// NOTE(review): the code returning it (Row.Scan) is outside this chunk.
var ErrNoRows = errors.New("sql: no rows in result set")
501
502
503
504
505
506
507
508
509
510
511
512
513
// DB is a handle to a pool of zero or more driver connections. The fields
// after mu are guarded by mu; waitDuration and numClosed are atomics and are
// accessed without it.
type DB struct {
	// waitDuration is the total time (nanoseconds) callers spent blocked in
	// conn waiting for a connection; reported by Stats.
	waitDuration atomic.Int64

	// connector opens new driver connections.
	connector driver.Connector
	// numClosed counts connections that have completed finalClose; Stmt
	// snapshots it (lastNumClosed) in prepareDC.
	numClosed atomic.Uint64

	mu sync.Mutex // protects following fields
	// freeConn holds idle connections. putConnDBLocked appends and conn takes
	// from the end, so the slice is ordered oldest-returned first.
	freeConn []*driverConn
	// connRequests holds the channels of goroutines waiting in conn.
	connRequests connRequestSet
	// numOpen counts open connections plus ones being opened; it is
	// incremented optimistically before an open is attempted.
	numOpen int
	// openerCh is signalled by maybeOpenNewConnections; connectionOpener
	// calls openNewConnection once per signal.
	openerCh chan struct{}
	closed   bool
	// dep tracks, for each finalCloser, the set of things still depending on
	// it; finalClose runs when the set becomes empty (see removeDepLocked).
	dep map[finalCloser]depSet
	// lastPut records the stack of each conn's last put; only written when
	// debugGetPut is true.
	lastPut           map[*driverConn]string
	maxIdleCount      int           // zero means defaultMaxIdleConns; negative means 0
	maxOpen           int           // <= 0 means unlimited
	maxLifetime       time.Duration // maximum amount of time a connection may be reused
	maxIdleTime       time.Duration // maximum amount of time a connection may be idle before being closed
	cleanerCh         chan struct{} // wakes connectionCleaner; nil when no cleaner is running
	waitCount         int64         // Total number of connections waited for.
	maxIdleClosed     int64         // Total number of connections closed due to idle count.
	maxIdleTimeClosed int64         // Total number of connections closed due to idle time.
	maxLifetimeClosed int64         // Total number of connections closed due to max connection lifetime limit.

	// stop cancels the context passed to the connectionOpener goroutine.
	stop func()
}
549
550
// connReuseStrategy tells db.conn whether it may hand out an idle pooled
// connection.
type connReuseStrategy uint8

const (
	// alwaysNewConn makes db.conn skip the idle pool (it may still wait for
	// a handed-over connection when maxOpen is reached).
	alwaysNewConn connReuseStrategy = 0
	// cachedOrNewConn lets db.conn reuse the most recently returned idle
	// connection before waiting or opening a new one.
	cachedOrNewConn connReuseStrategy = 1
)
561
562
563
564
565
// driverConn wraps a driver.Conn together with the pool bookkeeping for it.
// In this file, calls into ci are made while holding the embedded mutex
// (via withLock or explicit Lock/Unlock).
type driverConn struct {
	db        *DB
	createdAt time.Time // compared against db.maxLifetime by expired

	sync.Mutex // guards following
	ci         driver.Conn
	// needReset: the session should be reset (resetSession) before reuse.
	needReset   bool
	closed      bool
	finalClosed bool // ci.Close has been called
	// openStmt tracks statements prepared on this conn so finalClose can
	// close them first.
	openStmt map[*driverStmt]bool

	// guarded by db.mu
	inUse      bool
	dbmuClosed bool      // set by Close under db.mu; mirrors closed
	returnedAt time.Time // Time the connection was created or returned.
	onPut      []func()  // run with db.mu held when the conn is next returned (putConn)
}
583
584 func (dc *driverConn) releaseConn(err error) {
585 dc.db.putConn(dc, err, true)
586 }
587
588 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
589 dc.Lock()
590 defer dc.Unlock()
591 delete(dc.openStmt, ds)
592 }
593
594 func (dc *driverConn) expired(timeout time.Duration) bool {
595 if timeout <= 0 {
596 return false
597 }
598 return dc.createdAt.Add(timeout).Before(nowFunc())
599 }
600
601
602
603 func (dc *driverConn) resetSession(ctx context.Context) error {
604 dc.Lock()
605 defer dc.Unlock()
606
607 if !dc.needReset {
608 return nil
609 }
610 if cr, ok := dc.ci.(driver.SessionResetter); ok {
611 return cr.ResetSession(ctx)
612 }
613 return nil
614 }
615
616
617
618 func (dc *driverConn) validateConnection(needsReset bool) bool {
619 dc.Lock()
620 defer dc.Unlock()
621
622 if needsReset {
623 dc.needReset = true
624 }
625 if cv, ok := dc.ci.(driver.Validator); ok {
626 return cv.IsValid()
627 }
628 return true
629 }
630
631
632
633 func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
634 si, err := ctxDriverPrepare(ctx, dc.ci, query)
635 if err != nil {
636 return nil, err
637 }
638 ds := &driverStmt{Locker: dc, si: si}
639
640
641 if cg != nil {
642 return ds, nil
643 }
644
645
646
647
648
649 if dc.openStmt == nil {
650 dc.openStmt = make(map[*driverStmt]bool)
651 }
652 dc.openStmt[ds] = true
653 return ds, nil
654 }
655
656
// closeDBLocked marks dc closed and removes its self-dependency. The caller
// must hold dc.db.mu (DB.Close does). The returned function must be run
// after db.mu is released because it may be finalClose, which locks db.mu.
func (dc *driverConn) closeDBLocked() func() error {
	dc.Lock()
	defer dc.Unlock()
	if dc.closed {
		// Report the duplicate close through the deferred function.
		return func() error { return errors.New("sql: duplicate driverConn close") }
	}
	dc.closed = true
	return dc.db.removeDepLocked(dc, dc)
}

// Close marks dc closed (exactly once) and removes its self-dependency;
// if nothing else depends on dc, finalClose runs after db.mu is released.
func (dc *driverConn) Close() error {
	dc.Lock()
	if dc.closed {
		dc.Unlock()
		return errors.New("sql: duplicate driverConn close")
	}
	dc.closed = true
	dc.Unlock() // not defer; fn() below may run finalClose, which locks dc again

	// Updates that require holding db.mu.
	dc.db.mu.Lock()
	dc.dbmuClosed = true
	fn := dc.db.removeDepLocked(dc, dc)
	dc.db.mu.Unlock()
	return fn()
}

// finalClose closes the underlying driver connection once nothing depends
// on dc: it closes all tracked statements, closes ci, then decrements
// numOpen and lets waiting requests trigger replacement opens.
func (dc *driverConn) finalClose() error {
	var err error

	// Each *driverStmt uses dc as its Locker, so copy the set out under the
	// lock and close the statements after releasing it.
	var openStmt []*driverStmt
	withLock(dc, func() {
		openStmt = make([]*driverStmt, 0, len(dc.openStmt))
		for ds := range dc.openStmt {
			openStmt = append(openStmt, ds)
		}
		dc.openStmt = nil
	})
	for _, ds := range openStmt {
		ds.Close()
	}
	withLock(dc, func() {
		dc.finalClosed = true
		err = dc.ci.Close()
		dc.ci = nil
	})

	dc.db.mu.Lock()
	dc.db.numOpen--
	// A slot was freed; pending requests may now get a new connection.
	dc.db.maybeOpenNewConnections()
	dc.db.mu.Unlock()

	dc.db.numClosed.Add(1)
	return err
}
714
715
716
717
// driverStmt wraps a driver.Stmt so it is closed at most once and always
// under its Locker (the owning *driverConn when created in this file).
type driverStmt struct {
	sync.Locker
	si       driver.Stmt
	closed   bool
	closeErr error // result of the first Close, returned by later calls
}

// Close closes the underlying driver statement on the first call and
// returns that call's error on every call.
func (ds *driverStmt) Close() error {
	ds.Lock()
	defer ds.Unlock()
	if !ds.closed {
		ds.closed = true
		ds.closeErr = ds.si.Close()
	}
	return ds.closeErr
}
737
738
739 type depSet map[any]bool
740
741
742
743 type finalCloser interface {
744
745
746 finalClose() error
747 }
748
749
750
751 func (db *DB) addDep(x finalCloser, dep any) {
752 db.mu.Lock()
753 defer db.mu.Unlock()
754 db.addDepLocked(x, dep)
755 }
756
757 func (db *DB) addDepLocked(x finalCloser, dep any) {
758 if db.dep == nil {
759 db.dep = make(map[finalCloser]depSet)
760 }
761 xdep := db.dep[x]
762 if xdep == nil {
763 xdep = make(depSet)
764 db.dep[x] = xdep
765 }
766 xdep[dep] = true
767 }
768
769
770
771
772
773 func (db *DB) removeDep(x finalCloser, dep any) error {
774 db.mu.Lock()
775 fn := db.removeDepLocked(x, dep)
776 db.mu.Unlock()
777 return fn()
778 }
779
780 func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
781 xdep, ok := db.dep[x]
782 if !ok {
783 panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
784 }
785
786 l0 := len(xdep)
787 delete(xdep, dep)
788
789 switch len(xdep) {
790 case l0:
791
792 panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
793 case 0:
794
795 delete(db.dep, x)
796 return x.finalClose
797 default:
798
799 return func() error { return nil }
800 }
801 }
802
803
804
805
806
807
// connectionRequestQueueSize is the buffer size of DB.openerCh (see OpenDB).
var connectionRequestQueueSize = 1_000_000

// dsnConnector adapts a plain driver.Driver to driver.Connector by opening
// every connection with the stored data source name.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a connection via driver.Open; the context is ignored
// because driver.Open does not accept one.
func (c dsnConnector) Connect(context.Context) (driver.Conn, error) {
	return c.driver.Open(c.dsn)
}

// Driver returns the wrapped driver.
func (c dsnConnector) Driver() driver.Driver {
	return c.driver
}
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839 func OpenDB(c driver.Connector) *DB {
840 ctx, cancel := context.WithCancel(context.Background())
841 db := &DB{
842 connector: c,
843 openerCh: make(chan struct{}, connectionRequestQueueSize),
844 lastPut: make(map[*driverConn]string),
845 stop: cancel,
846 }
847
848 go db.connectionOpener(ctx)
849
850 return db
851 }
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870 func Open(driverName, dataSourceName string) (*DB, error) {
871 driversMu.RLock()
872 driveri, ok := drivers[driverName]
873 driversMu.RUnlock()
874 if !ok {
875 return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
876 }
877
878 if driverCtx, ok := driveri.(driver.DriverContext); ok {
879 connector, err := driverCtx.OpenConnector(dataSourceName)
880 if err != nil {
881 return nil, err
882 }
883 return OpenDB(connector), nil
884 }
885
886 return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
887 }
888
889 func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
890 var err error
891 if pinger, ok := dc.ci.(driver.Pinger); ok {
892 withLock(dc, func() {
893 err = pinger.Ping(ctx)
894 })
895 }
896 release(err)
897 return err
898 }
899
900
901
902 func (db *DB) PingContext(ctx context.Context) error {
903 var dc *driverConn
904 var err error
905
906 err = db.retry(func(strategy connReuseStrategy) error {
907 dc, err = db.conn(ctx, strategy)
908 return err
909 })
910
911 if err != nil {
912 return err
913 }
914
915 return db.pingDC(ctx, dc, dc.releaseConn)
916 }
917
918
919
920
921
922
923 func (db *DB) Ping() error {
924 return db.PingContext(context.Background())
925 }
926
927
928
929
930
931
932
// Close closes the database and prevents new queries from starting.
// It closes every idle connection, marks the pool closed, removes pending
// connection requests, stops the connection opener and, if the connector
// implements io.Closer, closes it. Connections still in use are closed when
// they are returned, because putConnDBLocked refuses to pool connections
// once db.closed is set. A second Close is a no-op returning nil. When
// several closes fail, the last error is returned.
func (db *DB) Close() error {
	db.mu.Lock()
	if db.closed { // Make DB.Close idempotent
		db.mu.Unlock()
		return nil
	}
	// Closing cleanerCh wakes connectionCleaner, which then sees db.closed
	// and exits.
	if db.cleanerCh != nil {
		close(db.cleanerCh)
	}
	var err error
	// Collect the close functions under db.mu but run them after unlocking:
	// they may be finalClose, which locks db.mu itself.
	fns := make([]func() error, 0, len(db.freeConn))
	for _, dc := range db.freeConn {
		fns = append(fns, dc.closeDBLocked())
	}
	db.freeConn = nil
	db.closed = true
	// NOTE(review): presumably closes each waiter's channel; conn maps a
	// closed request channel to errDBClosed.
	db.connRequests.CloseAndRemoveAll()
	db.mu.Unlock()
	for _, fn := range fns {
		err1 := fn()
		if err1 != nil {
			err = err1
		}
	}
	// Cancel the context the connectionOpener goroutine is waiting on.
	db.stop()
	if c, ok := db.connector.(io.Closer); ok {
		err1 := c.Close()
		if err1 != nil {
			err = err1
		}
	}
	return err
}
966
967 const defaultMaxIdleConns = 2
968
969 func (db *DB) maxIdleConnsLocked() int {
970 n := db.maxIdleCount
971 switch {
972 case n == 0:
973
974 return defaultMaxIdleConns
975 case n < 0:
976 return 0
977 default:
978 return n
979 }
980 }
981
982 func (db *DB) shortestIdleTimeLocked() time.Duration {
983 if db.maxIdleTime <= 0 {
984 return db.maxLifetime
985 }
986 if db.maxLifetime <= 0 {
987 return db.maxIdleTime
988 }
989 return min(db.maxIdleTime, db.maxLifetime)
990 }
991
992
993
994
995
996
997
998
999
1000
1001
1002 func (db *DB) SetMaxIdleConns(n int) {
1003 db.mu.Lock()
1004 if n > 0 {
1005 db.maxIdleCount = n
1006 } else {
1007
1008 db.maxIdleCount = -1
1009 }
1010
1011 if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
1012 db.maxIdleCount = db.maxOpen
1013 }
1014 var closing []*driverConn
1015 idleCount := len(db.freeConn)
1016 maxIdle := db.maxIdleConnsLocked()
1017 if idleCount > maxIdle {
1018 closing = db.freeConn[maxIdle:]
1019 db.freeConn = db.freeConn[:maxIdle]
1020 }
1021 db.maxIdleClosed += int64(len(closing))
1022 db.mu.Unlock()
1023 for _, c := range closing {
1024 c.Close()
1025 }
1026 }
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036 func (db *DB) SetMaxOpenConns(n int) {
1037 db.mu.Lock()
1038 db.maxOpen = n
1039 if n < 0 {
1040 db.maxOpen = 0
1041 }
1042 syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
1043 db.mu.Unlock()
1044 if syncMaxIdle {
1045 db.SetMaxIdleConns(n)
1046 }
1047 }
1048
1049
1050
1051
1052
1053
1054 func (db *DB) SetConnMaxLifetime(d time.Duration) {
1055 if d < 0 {
1056 d = 0
1057 }
1058 db.mu.Lock()
1059
1060 if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
1061 select {
1062 case db.cleanerCh <- struct{}{}:
1063 default:
1064 }
1065 }
1066 db.maxLifetime = d
1067 db.startCleanerLocked()
1068 db.mu.Unlock()
1069 }
1070
1071
1072
1073
1074
1075
1076 func (db *DB) SetConnMaxIdleTime(d time.Duration) {
1077 if d < 0 {
1078 d = 0
1079 }
1080 db.mu.Lock()
1081 defer db.mu.Unlock()
1082
1083
1084 if d > 0 && d < db.shortestIdleTimeLocked() && db.cleanerCh != nil {
1085 select {
1086 case db.cleanerCh <- struct{}{}:
1087 default:
1088 }
1089 }
1090 db.maxIdleTime = d
1091 db.startCleanerLocked()
1092 }
1093
1094
1095 func (db *DB) startCleanerLocked() {
1096 if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
1097 db.cleanerCh = make(chan struct{}, 1)
1098 go db.connectionCleaner(db.shortestIdleTimeLocked())
1099 }
1100 }
1101
// connectionCleaner is the background goroutine started by
// startCleanerLocked. It periodically removes idle connections past
// maxIdleTime or maxLifetime and closes them. It wakes on its timer or on
// cleanerCh (a signal from a setter, or its close by DB.Close). It exits,
// resetting cleanerCh to nil, once the DB is closed, has no open
// connections, or has no limit set.
func (db *DB) connectionCleaner(d time.Duration) {
	const minInterval = time.Second

	// Never poll more often than once per minInterval.
	if d < minInterval {
		d = minInterval
	}
	t := time.NewTimer(d)

	for {
		select {
		case <-t.C:
		case <-db.cleanerCh: // maxLifetime/maxIdleTime was changed or db was closed.
		}

		db.mu.Lock()

		d = db.shortestIdleTimeLocked()
		if db.closed || db.numOpen == 0 || d <= 0 {
			db.cleanerCh = nil
			db.mu.Unlock()
			return
		}

		// This d shadows the outer one; it is the delay until the next
		// connection is due, used only to re-arm the timer below.
		d, closing := db.connectionCleanerRunLocked(d)
		db.mu.Unlock()
		for _, c := range closing {
			c.Close()
		}

		if d < minInterval {
			d = minInterval
		}

		// Stop and drain the timer before Reset so a stale tick is not
		// delivered on the next iteration.
		if !t.Stop() {
			select {
			case <-t.C:
			default:
			}
		}
		t.Reset(d)
	}
}
1144
1145
1146
1147
// connectionCleanerRunLocked removes from freeConn every connection that has
// exceeded maxIdleTime or maxLifetime and returns them for the caller to
// close after releasing db.mu. It also returns the next wake-up delay, which
// is d lowered to the time until the next remaining connection would expire.
// Requires db.mu.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
	var idleClosing int64
	var closing []*driverConn
	if db.maxIdleTime > 0 {
		// freeConn is ordered by returnedAt (appended on return), so scan
		// from the newest end: the first stale entry found marks everything
		// before it as stale too.
		idleSince := nowFunc().Add(-db.maxIdleTime)
		last := len(db.freeConn) - 1
		for i := last; i >= 0; i-- {
			c := db.freeConn[i]
			if c.returnedAt.Before(idleSince) {
				i++
				// Three-index slice caps closing so the append below
				// reallocates instead of writing into freeConn's storage.
				closing = db.freeConn[:i:i]
				db.freeConn = db.freeConn[i:]
				idleClosing = int64(len(closing))
				db.maxIdleTimeClosed += idleClosing
				break
			}
		}

		if len(db.freeConn) > 0 {
			c := db.freeConn[0]
			if d2 := c.returnedAt.Sub(idleSince); d2 < d {
				// Ensure idle connections are cleaned up as soon as
				// possible.
				d = d2
			}
		}
	}

	if db.maxLifetime > 0 {
		expiredSince := nowFunc().Add(-db.maxLifetime)
		for i := 0; i < len(db.freeConn); i++ {
			c := db.freeConn[i]
			if c.createdAt.Before(expiredSince) {
				closing = append(closing, c)

				last := len(db.freeConn) - 1
				// Order-preserving delete (freeConn must stay ordered by
				// returnedAt); nil the vacated slot so it is not retained.
				copy(db.freeConn[i:], db.freeConn[i+1:])
				db.freeConn[last] = nil
				db.freeConn = db.freeConn[:last]
				i--
			} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
				// Prevent connections sitting in freeConn after they have
				// expired by moving the next deadline d earlier.
				d = d2
			}
		}
		// closing also holds the idle-time closures counted above.
		db.maxLifetimeClosed += int64(len(closing)) - idleClosing
	}

	return d, closing
}
1203
1204
1205 type DBStats struct {
1206 MaxOpenConnections int
1207
1208
1209 OpenConnections int
1210 InUse int
1211 Idle int
1212
1213
1214 WaitCount int64
1215 WaitDuration time.Duration
1216 MaxIdleClosed int64
1217 MaxIdleTimeClosed int64
1218 MaxLifetimeClosed int64
1219 }
1220
1221
1222 func (db *DB) Stats() DBStats {
1223 wait := db.waitDuration.Load()
1224
1225 db.mu.Lock()
1226 defer db.mu.Unlock()
1227
1228 stats := DBStats{
1229 MaxOpenConnections: db.maxOpen,
1230
1231 Idle: len(db.freeConn),
1232 OpenConnections: db.numOpen,
1233 InUse: db.numOpen - len(db.freeConn),
1234
1235 WaitCount: db.waitCount,
1236 WaitDuration: time.Duration(wait),
1237 MaxIdleClosed: db.maxIdleClosed,
1238 MaxIdleTimeClosed: db.maxIdleTimeClosed,
1239 MaxLifetimeClosed: db.maxLifetimeClosed,
1240 }
1241 return stats
1242 }
1243
1244
1245
1246
// maybeOpenNewConnections asks the connectionOpener goroutine to open one
// connection per pending connection request, capped by the room left under
// maxOpen. numOpen is incremented optimistically before each signal, and
// openNewConnection decrements it again if the open does not produce a kept
// connection. Callers in this file hold db.mu.
func (db *DB) maybeOpenNewConnections() {
	numRequests := db.connRequests.Len()
	if db.maxOpen > 0 {
		numCanOpen := db.maxOpen - db.numOpen
		if numRequests > numCanOpen {
			numRequests = numCanOpen
		}
	}
	for numRequests > 0 {
		db.numOpen++ // optimistically
		numRequests--
		if db.closed {
			return
		}
		// openerCh is buffered (connectionRequestQueueSize).
		db.openerCh <- struct{}{}
	}
}
1264
1265
1266 func (db *DB) connectionOpener(ctx context.Context) {
1267 for {
1268 select {
1269 case <-ctx.Done():
1270 return
1271 case <-db.openerCh:
1272 db.openNewConnection(ctx)
1273 }
1274 }
1275 }
1276
1277
// openNewConnection opens one connection on behalf of connectionOpener and
// hands it to a waiting request or the idle pool.
func (db *DB) openNewConnection(ctx context.Context) {
	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
	// on db.openerCh. This function must execute db.numOpen-- if the
	// connection fails or is closed before returning.
	ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
	if err != nil {
		db.numOpen--
		// Forward the dial error to one waiting request (nil conn), then try
		// to schedule opens for any remaining waiters.
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
	}
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		// Neither handed off nor pooled (see putConnDBLocked): discard it.
		db.numOpen--
		ci.Close()
	}
}
1311
1312
1313
1314
1315 type connRequest struct {
1316 conn *driverConn
1317 err error
1318 }
1319
1320 var errDBClosed = errors.New("sql: database is closed")
1321
1322
// conn returns a newly opened or cached *driverConn with inUse set.
//
// With cachedOrNewConn it first reuses the most recently returned idle
// connection. If maxOpen has been reached it queues a request and waits for
// a handed-over connection or for ctx to finish. Otherwise it opens a new
// connection through the connector. driver.ErrBadConn is returned when a
// reused connection is expired or its session reset reports a bad
// connection, so that retry can try again.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
		// Reuse the most recently returned (least idle) connection so the
		// ones that stay idle can be closed as soon as possible.
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// Reset the session if required.
		if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
			conn.Close()
			return nil, err
		}

		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// Make the connRequest channel. It's buffered so that the sender
		// (putConnDBLocked, which holds db.mu) never blocks on it.
		req := make(chan connRequest, 1)
		delHandle := db.connRequests.Add(req)
		db.waitCount++
		db.mu.Unlock()

		waitStart := nowFunc()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			deleted := db.connRequests.Delete(delHandle)
			db.mu.Unlock()

			db.waitDuration.Add(int64(time.Since(waitStart)))

			// If we failed to delete it, either the DB was closed or a sender
			// already took the request and sent (or is sending) on it.
			if !deleted {
				// Best effort: if a connection already arrived, give it back
				// to the pool (without a session reset) instead of leaking it.
				select {
				default:
				case ret, ok := <-req:
					if ok && ret.conn != nil {
						db.putConn(ret.conn, ret.err, false)
					}
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			db.waitDuration.Add(int64(time.Since(waitStart)))

			// NOTE(review): a closed channel presumably comes from DB.Close
			// (connRequests.CloseAndRemoveAll).
			if !ok {
				return nil, errDBClosed
			}
			// Only check if the connection is expired if the strategy is cachedOrNewConn.
			// If we require a new connection, just re-use the connection without looking
			// at the expiry time. If it is expired, it will be checked when it is placed
			// back into the connection pool.
			// This prioritizes giving a valid connection to a client over the exact connection
			// lifetime, which could expire exactly after this point anyway.
			if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			// A failed open was forwarded to this waiter by openNewConnection.
			if ret.conn == nil {
				return nil, ret.err
			}

			// Reset the session if required.
			if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
				ret.conn.Close()
				return nil, err
			}
			return ret.conn, ret.err
		}
	}

	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
1458
1459
1460 var putConnHook func(*DB, *driverConn)
1461
1462
1463
1464
1465 func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
1466 db.mu.Lock()
1467 defer db.mu.Unlock()
1468 if c.inUse {
1469 c.onPut = append(c.onPut, func() {
1470 ds.Close()
1471 })
1472 } else {
1473 c.Lock()
1474 fc := c.finalClosed
1475 c.Unlock()
1476 if !fc {
1477 ds.Close()
1478 }
1479 }
1480 }
1481
1482
1483
1484 const debugGetPut = false
1485
1486
1487
// putConn returns dc to the pool. err is the last error seen on the
// connection, if any. resetSession is passed to validateConnection, which
// then marks the session for reset before the next use. Bad or expired
// connections are closed instead of pooled. It panics if dc is not
// currently checked out.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	if !errors.Is(err, driver.ErrBadConn) {
		if !dc.validateConnection(resetSession) {
			err = driver.ErrBadConn
		}
	}
	db.mu.Lock()
	if !dc.inUse {
		db.mu.Unlock()
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}

	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false
	dc.returnedAt = nowFunc()

	// Run deferred work queued by noteUnusedDriverStatement (db.mu held).
	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if errors.Is(err, driver.ErrBadConn) {
		// Don't reuse bad connections.
		// Since the conn is considered bad and is being discarded, treat it
		// as closed. Don't decrement the open count here, finalClose will
		// take care of that.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	// Not handed to a waiter nor kept idle: close it.
	if !added {
		dc.Close()
		return
	}
}
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
// putConnDBLocked satisfies a pending connRequest, if any, with dc and err.
// Otherwise, when err is nil, it stores dc in the idle pool if
// maxIdleConnsLocked leaves room. It returns true if dc (or err) was handed
// off or pooled. It returns false when the caller should dispose of dc: the
// DB is closed, numOpen exceeds maxOpen, or the idle pool is full.
//
// openNewConnection passes a nil dc with a non-nil err to forward a dial
// failure to a waiter. Requires db.mu.
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
	if db.closed {
		return false
	}
	// More connections are open than allowed (maxOpen may have been lowered
	// by SetMaxOpenConns); let the caller close this one.
	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
		return false
	}
	if req, ok := db.connRequests.TakeRandom(); ok {
		if err == nil {
			dc.inUse = true
		}
		// Request channels are created with capacity 1 in conn, so this send
		// does not block while db.mu is held.
		req <- connRequest{
			conn: dc,
			err:  err,
		}
		return true
	} else if err == nil && !db.closed {
		if db.maxIdleConnsLocked() > len(db.freeConn) {
			db.freeConn = append(db.freeConn, dc)
			db.startCleanerLocked()
			return true
		}
		db.maxIdleClosed++
	}
	return false
}
1575
1576
1577
1578
1579 const maxBadConnRetries = 2
1580
1581 func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
1582 for i := int64(0); i < maxBadConnRetries; i++ {
1583 err := fn(cachedOrNewConn)
1584
1585 if err == nil || !errors.Is(err, driver.ErrBadConn) {
1586 return err
1587 }
1588 }
1589
1590 return fn(alwaysNewConn)
1591 }
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1602 var stmt *Stmt
1603 var err error
1604
1605 err = db.retry(func(strategy connReuseStrategy) error {
1606 stmt, err = db.prepare(ctx, query, strategy)
1607 return err
1608 })
1609
1610 return stmt, err
1611 }
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621 func (db *DB) Prepare(query string) (*Stmt, error) {
1622 return db.PrepareContext(context.Background(), query)
1623 }
1624
1625 func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
1626
1627
1628
1629
1630
1631
1632 dc, err := db.conn(ctx, strategy)
1633 if err != nil {
1634 return nil, err
1635 }
1636 return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
1637 }
1638
1639
1640
1641
// prepareDC prepares query on dc and always calls release before returning.
// With cg == nil the statement belongs to the pool: it records the
// (dc, ds) pair and the current numClosed, and registers itself as a DB
// dependency. With cg != nil the statement is bound to cg and none of that
// bookkeeping is done.
func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
	var ds *driverStmt
	var err error
	// release receives err's final value: the prepare error, or nil.
	defer func() {
		release(err)
	}()
	withLock(dc, func() {
		ds, err = dc.prepareLocked(ctx, cg, query)
	})
	if err != nil {
		return nil, err
	}
	stmt := &Stmt{
		db:    db,
		query: query,
		cg:    cg,
		cgds:  ds,
	}

	// When cg == nil this statement will need to keep track of the
	// connections it is prepared on and record the stmt dependency on
	// the DB.
	if cg == nil {
		stmt.css = []connStmt{{dc, ds}}
		stmt.lastNumClosed = db.numClosed.Load()
		db.addDep(stmt, stmt)
	}
	return stmt, nil
}
1671
1672
1673
1674 func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
1675 var res Result
1676 var err error
1677
1678 err = db.retry(func(strategy connReuseStrategy) error {
1679 res, err = db.exec(ctx, query, args, strategy)
1680 return err
1681 })
1682
1683 return res, err
1684 }
1685
1686
1687
1688
1689
1690
1691 func (db *DB) Exec(query string, args ...any) (Result, error) {
1692 return db.ExecContext(context.Background(), query, args...)
1693 }
1694
1695 func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
1696 dc, err := db.conn(ctx, strategy)
1697 if err != nil {
1698 return nil, err
1699 }
1700 return db.execDC(ctx, dc, dc.releaseConn, query, args)
1701 }
1702
// execDC runs query with args on dc and always calls release with the final
// err before returning. It first tries the connection's direct path
// (driver.ExecerContext, else driver.Execer). If the driver has neither, or
// answers driver.ErrSkip, it prepares a one-off statement, executes it and
// closes it.
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
	// err is the named result, so release sees the value actually returned.
	defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var resi driver.Result
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		// driver.ErrSkip: the driver declined the direct path; fall through
		// to prepare + exec below.
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	// Deferred after release, so it runs first: the statement is closed
	// while the connection is still held.
	defer ds.Close()
	return resultFromStatement(ctx, dc.ci, ds, args...)
}
1741
1742
1743
1744 func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
1745 var rows *Rows
1746 var err error
1747
1748 err = db.retry(func(strategy connReuseStrategy) error {
1749 rows, err = db.query(ctx, query, args, strategy)
1750 return err
1751 })
1752
1753 return rows, err
1754 }
1755
1756
1757
1758
1759
1760
1761 func (db *DB) Query(query string, args ...any) (*Rows, error) {
1762 return db.QueryContext(context.Background(), query, args...)
1763 }
1764
1765 func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
1766 dc, err := db.conn(ctx, strategy)
1767 if err != nil {
1768 return nil, err
1769 }
1770
1771 return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
1772 }
1773
1774
1775
1776
1777
// queryDC executes query on dc. On error releaseConn is called immediately;
// on success releaseConn is stored in the returned Rows.
// NOTE(review): when Rows calls it is decided by Rows, not visible here.
// ctx comes from the query method; txctx is an optional transaction
// context, nil when called from DB.query.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// driver.ErrSkip: the driver declined the direct path; fall through
		// to prepare + query below.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// Note: ownership of dc passes to the *Rows, to be freed
			// with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

	ds := &driverStmt{Locker: dc, si: si}
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// Note: ownership of ci passes to the *Rows, to be freed
	// with releaseConn; the one-off statement is handed over as closeStmt.
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}
1841
1842
1843
1844
1845
1846
1847
1848 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
1849 rows, err := db.QueryContext(ctx, query, args...)
1850 return &Row{rows: rows, err: err}
1851 }
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862 func (db *DB) QueryRow(query string, args ...any) *Row {
1863 return db.QueryRowContext(context.Background(), query, args...)
1864 }
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1877 var tx *Tx
1878 var err error
1879
1880 err = db.retry(func(strategy connReuseStrategy) error {
1881 tx, err = db.begin(ctx, opts, strategy)
1882 return err
1883 })
1884
1885 return tx, err
1886 }
1887
1888
1889
1890
1891
1892
1893 func (db *DB) Begin() (*Tx, error) {
1894 return db.BeginTx(context.Background(), nil)
1895 }
1896
1897 func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
1898 dc, err := db.conn(ctx, strategy)
1899 if err != nil {
1900 return nil, err
1901 }
1902 return db.beginDC(ctx, dc, dc.releaseConn, opts)
1903 }
1904
1905
// beginDC starts a driver transaction on dc and wraps it in a Tx.
// If starting fails, release is called with the error and nil is returned.
// Otherwise release is stored as tx.releaseConn and called when the Tx closes.
// A goroutine (tx.awaitDone) is started that rolls the transaction back once
// tx.ctx is done.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	keepConnOnRollback := false
	withLock(dc, func() {
		// The connection is kept after awaitDone's rollback only when the
		// driver can both reset the session and validate the connection.
		_, hasSessionResetter := dc.ci.(driver.SessionResetter)
		_, hasConnectionValidator := dc.ci.(driver.Validator)
		keepConnOnRollback = hasSessionResetter && hasConnectionValidator
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	// tx.ctx ends in two ways: Commit/rollback call tx.cancel, or the caller's
	// ctx is done. awaitDone waits on it and then rolls back; that rollback
	// does nothing more than return ErrTxDone if the Tx has already finished.
	ctx, cancel := context.WithCancel(ctx)
	tx = &Tx{
		db: db,
		dc: dc,
		releaseConn: release,
		txi: txi,
		cancel: cancel,
		keepConnOnRollback: keepConnOnRollback,
		ctx: ctx,
	}
	go tx.awaitDone()
	return tx, nil
}
1935
1936
1937 func (db *DB) Driver() driver.Driver {
1938 return db.connector.Driver()
1939 }
1940
1941
1942
// ErrConnDone is returned by operations on a Conn that has already been
// closed. Conn.grabConn returns it when c.done is set, and Conn.close returns
// it when called a second time.
var ErrConnDone = errors.New("sql: connection is already closed")
1944
1945
1946
1947
1948
1949
1950
1951
1952 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
1953 var dc *driverConn
1954 var err error
1955
1956 err = db.retry(func(strategy connReuseStrategy) error {
1957 dc, err = db.conn(ctx, strategy)
1958 return err
1959 })
1960
1961 if err != nil {
1962 return nil, err
1963 }
1964
1965 conn := &Conn{
1966 db: db,
1967 dc: dc,
1968 }
1969 return conn, nil
1970 }
1971
// releaseConn is the callback a connection grabber hands out. The caller
// invokes it when finished with the connection, passing the error the
// operation ended with, if any.
type releaseConn func(error)
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
// Conn is a single driver connection taken from a DB (see DB.Conn). All of
// its operations run on that connection until Close is called.
type Conn struct {
	db *DB

	// closemu is read-locked by grabConn for the duration of each operation.
	// close takes the write lock, so it waits for operations in flight.
	closemu sync.RWMutex

	// dc is the underlying driver connection; close sets it to nil.
	dc *driverConn

	// done is set exactly once, by close; grabConn checks it.
	done atomic.Bool

	// releaseConnOnce guards the one-time initialization of releaseConnCache.
	releaseConnOnce sync.Once

	// releaseConnCache holds the c.closemuRUnlockCondReleaseConn method value
	// that grabConn returns. It is built only once, presumably to avoid
	// creating a new method value on every call.
	releaseConnCache releaseConn
}
2004
2005
2006
// grabConn implements stmtConnGrabber for Conn. It returns ErrConnDone once
// the Conn is closed. Otherwise it read-locks closemu and returns the driver
// connection together with a release func. The caller must call that release
// func to drop the read lock. The context argument is not used.
func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
	if c.done.Load() {
		return nil, nil, ErrConnDone
	}
	// Build the release callback once and reuse it on later calls.
	c.releaseConnOnce.Do(func() {
		c.releaseConnCache = c.closemuRUnlockCondReleaseConn
	})
	// Held until the caller invokes the returned release func.
	c.closemu.RLock()
	return c.dc, c.releaseConnCache, nil
}
2017
2018
2019 func (c *Conn) PingContext(ctx context.Context) error {
2020 dc, release, err := c.grabConn(ctx)
2021 if err != nil {
2022 return err
2023 }
2024 return c.db.pingDC(ctx, dc, release)
2025 }
2026
2027
2028
2029 func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
2030 dc, release, err := c.grabConn(ctx)
2031 if err != nil {
2032 return nil, err
2033 }
2034 return c.db.execDC(ctx, dc, release, query, args)
2035 }
2036
2037
2038
2039 func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
2040 dc, release, err := c.grabConn(ctx)
2041 if err != nil {
2042 return nil, err
2043 }
2044 return c.db.queryDC(ctx, nil, dc, release, query, args)
2045 }
2046
2047
2048
2049
2050
2051
2052
2053 func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2054 rows, err := c.QueryContext(ctx, query, args...)
2055 return &Row{rows: rows, err: err}
2056 }
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2067 dc, release, err := c.grabConn(ctx)
2068 if err != nil {
2069 return nil, err
2070 }
2071 return c.db.prepareDC(ctx, dc, release, c, query)
2072 }
2073
2074
2075
2076
2077
2078
// Raw calls f with the underlying driver connection (dc.ci) while holding the
// driver connection's mutex, and returns f's error.
// If f panics, the connection is released with driver.ErrBadConn. That makes
// the release callback from grabConn close this Conn, and the panic keeps
// propagating. Returns ErrConnDone if the Conn is already closed.
func (c *Conn) Raw(f func(driverConn any) error) (err error) {
	var dc *driverConn
	var release releaseConn

	// grabConn does not use its context argument, so nil is passed here.
	dc, release, err = c.grabConn(nil)
	if err != nil {
		return
	}
	fPanic := true
	dc.Mutex.Lock()
	defer func() {
		dc.Mutex.Unlock()

		// fPanic is still true only if f panicked. Report a bad connection
		// so the release callback discards it instead of reusing it.
		if fPanic {
			err = driver.ErrBadConn
		}
		release(err)
	}()
	err = f(dc.ci)
	fPanic = false

	return
}
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
2118 dc, release, err := c.grabConn(ctx)
2119 if err != nil {
2120 return nil, err
2121 }
2122 return c.db.beginDC(ctx, dc, release, opts)
2123 }
2124
2125
2126
// closemuRUnlockCondReleaseConn is the release func that Conn.grabConn hands
// out. It drops the closemu read lock taken there. If the operation failed
// with driver.ErrBadConn, it also closes the Conn, which releases the driver
// connection with that error.
func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
	// Unlock first: close needs the write lock.
	c.closemu.RUnlock()
	if errors.Is(err, driver.ErrBadConn) {
		c.close(err)
	}
}
2133
// txCtx implements stmtConnGrabber. A Conn has no transaction context, so it
// returns nil.
func (c *Conn) txCtx() context.Context {
	return nil
}
2137
// close marks the Conn done, releases the driver connection with err, and
// clears the Conn's db and dc references. It returns ErrConnDone if the Conn
// was already closed, otherwise err.
func (c *Conn) close(err error) error {
	if !c.done.CompareAndSwap(false, true) {
		return ErrConnDone
	}

	// Operations that grabbed the connection hold closemu for reading.
	// Taking the write lock waits until all of them have released it.

	c.closemu.Lock()
	defer c.closemu.Unlock()

	c.dc.releaseConn(err)
	c.dc = nil
	c.db = nil
	return err
}
2153
2154
2155
2156
2157
2158
2159 func (c *Conn) Close() error {
2160 return c.close(nil)
2161 }
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
// Tx is an in-progress transaction bound to one driver connection, created by
// DB.beginDC. It ends with Commit or Rollback. If its context is done first,
// awaitDone rolls it back.
type Tx struct {
	db *DB

	// closemu is read-locked by grabConn for each operation. Commit and
	// rollback take the write lock, so they wait for operations in flight.
	closemu sync.RWMutex

	// dc is the driver connection and txi the driver transaction.
	// close sets both to nil.
	dc *driverConn
	txi driver.Tx

	// releaseConn is called by close to hand dc back, with the error the
	// transaction ended with.
	releaseConn func(error)

	// done flips to true exactly once, in Commit or rollback. After that,
	// operations fail with ErrTxDone.
	done atomic.Bool

	// keepConnOnRollback stops awaitDone from discarding the connection on
	// its context-triggered rollback. beginDC sets it when the driver
	// implements both driver.SessionResetter and driver.Validator.
	keepConnOnRollback bool

	// stmts records statements created within this transaction
	// (PrepareContext, StmtContext) so closePrepared can close them.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel cancels ctx; Commit and rollback call it.
	cancel func()

	// ctx is derived via context.WithCancel from the ctx given to beginDC.
	ctx context.Context
}
2213
2214
2215
2216 func (tx *Tx) awaitDone() {
2217
2218
2219 <-tx.ctx.Done()
2220
2221
2222
2223
2224
2225
2226
2227 discardConnection := !tx.keepConnOnRollback
2228 tx.rollback(discardConnection)
2229 }
2230
// isDone reports whether Commit or rollback has already set tx.done.
func (tx *Tx) isDone() bool {
	return tx.done.Load()
}
2234
2235
2236
// ErrTxDone is returned by operations on a Tx that has already been committed
// or rolled back. Tx.grabConn, Commit and rollback return it once tx.done is
// set.
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
2238
2239
2240
2241
2242 func (tx *Tx) close(err error) {
2243 tx.releaseConn(err)
2244 tx.dc = nil
2245 tx.txi = nil
2246 }
2247
2248
2249
// hookTxGrabConn, when non-nil, is called by Tx.grabConn after the done check,
// while closemu is still read-locked. NOTE(review): it appears to be set only
// by tests; confirm before relying on it.
var hookTxGrabConn func()
2251
// grabConn implements stmtConnGrabber for Tx. It returns ctx.Err() if ctx is
// already done, and ErrTxDone once the transaction has finished. On success
// closemu stays read-locked until the returned release func
// (closemuRUnlockRelease) is called.
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
	select {
	default:
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}

	// Take the read lock before checking done. Commit/rollback take the write
	// lock, so they wait for this operation's release before using the
	// driver transaction.
	tx.closemu.RLock()
	if tx.isDone() {
		tx.closemu.RUnlock()
		return nil, nil, ErrTxDone
	}
	if hookTxGrabConn != nil {
		hookTxGrabConn()
	}
	return tx.dc, tx.closemuRUnlockRelease, nil
}
2271
// txCtx implements stmtConnGrabber by returning the transaction's context.
// That context is cancelled when the transaction ends.
func (tx *Tx) txCtx() context.Context {
	return tx.ctx
}
2275
2276
2277
2278
2279
// closemuRUnlockRelease is the release func that Tx.grabConn hands out. It
// drops the closemu read lock. The error argument is ignored.
func (tx *Tx) closemuRUnlockRelease(error) {
	tx.closemu.RUnlock()
}
2283
2284
2285 func (tx *Tx) closePrepared() {
2286 tx.stmts.Lock()
2287 defer tx.stmts.Unlock()
2288 for _, stmt := range tx.stmts.v {
2289 stmt.Close()
2290 }
2291 }
2292
2293
// Commit commits the transaction. It returns ErrTxDone if the transaction has
// already finished. If tx.ctx is done before the commit starts, it returns the
// context's error.
func (tx *Tx) Commit() error {
	// Check the context before flipping done. tx.ctx is also cancelled by a
	// finished Commit/rollback, so that case is reported as ErrTxDone.
	// Otherwise the caller's context ended, and awaitDone will roll back.
	select {
	default:
	case <-tx.ctx.Done():
		if tx.done.Load() {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	// done is now set, so new grabConn calls fail. Cancelling tx.ctx makes
	// Rows opened in this transaction close, releasing their closemu read
	// locks. Taking and releasing the write lock then waits for every
	// remaining operation in flight.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	// Prepared statements are closed unless the connection went bad.
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
2328
// rollbackHook, when non-nil, is called by Tx.rollback right after the
// transaction is marked done. NOTE(review): it appears to be a test-only hook.
var rollbackHook func()
2330
2331
2332
// rollback aborts the transaction. It returns ErrTxDone if the transaction has
// already finished. When discardConn is true, the connection is released with
// driver.ErrBadConn and that error is returned, replacing whatever the driver
// Rollback reported.
func (tx *Tx) rollback(discardConn bool) error {
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	if rollbackHook != nil {
		rollbackHook()
	}

	// done is now set, so new grabConn calls fail. Cancelling tx.ctx makes
	// Rows opened in this transaction close, releasing their closemu read
	// locks. Taking and releasing the write lock then waits for every
	// remaining operation in flight.

	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	// Prepared statements are closed unless the connection went bad.
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}
2363
2364
2365 func (tx *Tx) Rollback() error {
2366 return tx.rollback(false)
2367 }
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2380 dc, release, err := tx.grabConn(ctx)
2381 if err != nil {
2382 return nil, err
2383 }
2384
2385 stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
2386 if err != nil {
2387 return nil, err
2388 }
2389 tx.stmts.Lock()
2390 tx.stmts.v = append(tx.stmts.v, stmt)
2391 tx.stmts.Unlock()
2392 return stmt, nil
2393 }
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404 func (tx *Tx) Prepare(query string) (*Stmt, error) {
2405 return tx.PrepareContext(context.Background(), query)
2406 }
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
// StmtContext returns a copy of stmt bound to this transaction's connection.
// Errors are not returned directly. They are stored as the returned Stmt's
// stickyErr and reported when that Stmt is used or closed.
//
// If stmt is open and not already bound to a Tx or Conn, two things happen:
//   - the driver statement stmt already has for this connection is reused, or
//     a new one is prepared and added to stmt.css;
//   - the new Stmt records stmt as its parentStmt.
//
// Otherwise the query is re-prepared on the transaction's connection.
// Either way, the new Stmt is added to tx.stmts.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// A closed stmt, or one already bound to a Tx/Conn, cannot share its
		// per-connection statements. Re-prepare the query on dc instead,
		// with no parent link.




		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()

		// Reuse the driver statement already prepared on this connection,
		// if there is one.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		if si == nil {
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si: si,
		},
		parentStmt: parentStmt,
		query: stmt.query,
	}
	if parentStmt != nil {
		tx.db.addDep(parentStmt, txs)
	}
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}
2495
2496
2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
2513 return tx.StmtContext(context.Background(), stmt)
2514 }
2515
2516
2517
2518 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
2519 dc, release, err := tx.grabConn(ctx)
2520 if err != nil {
2521 return nil, err
2522 }
2523 return tx.db.execDC(ctx, dc, release, query, args)
2524 }
2525
2526
2527
2528
2529
2530
2531 func (tx *Tx) Exec(query string, args ...any) (Result, error) {
2532 return tx.ExecContext(context.Background(), query, args...)
2533 }
2534
2535
2536 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
2537 dc, release, err := tx.grabConn(ctx)
2538 if err != nil {
2539 return nil, err
2540 }
2541
2542 return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
2543 }
2544
2545
2546
2547
2548
2549 func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
2550 return tx.QueryContext(context.Background(), query, args...)
2551 }
2552
2553
2554
2555
2556
2557
2558
2559 func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
2560 rows, err := tx.QueryContext(ctx, query, args...)
2561 return &Row{rows: rows, err: err}
2562 }
2563
2564
2565
2566
2567
2568
2569
2570
2571
2572
2573 func (tx *Tx) QueryRow(query string, args ...any) *Row {
2574 return tx.QueryRowContext(context.Background(), query, args...)
2575 }
2576
2577
// connStmt pairs a driver connection with the driver statement prepared on it.
// Stmt.css holds one entry per connection the Stmt has been prepared on.
type connStmt struct {
	dc *driverConn
	ds *driverStmt
}
2582
2583
2584
// stmtConnGrabber is implemented by *Tx and *Conn. These are the holders a
// Stmt can be bound to (Stmt.cg) so that it always runs on one specific
// connection.
type stmtConnGrabber interface {
	// grabConn returns the connection to run on, plus a release func the
	// caller must invoke when finished with it.
	grabConn(context.Context) (*driverConn, releaseConn, error)

	// txCtx returns the transaction context that Rows watch so they close
	// when the transaction ends, or nil when there is none (Conn).
	txCtx() context.Context
}
2595
2596 var (
2597 _ stmtConnGrabber = &Tx{}
2598 _ stmtConnGrabber = &Conn{}
2599 )
2600
2601
2602
2603
2604
2605
2606
2607
2608
2609
// Stmt is a prepared statement.
//   - A Stmt created on a DB is prepared on each pool connection it runs on;
//     css records those per-connection statements.
//   - A Stmt bound to a Tx or Conn (cg) always runs on that one connection,
//     using cgds.
type Stmt struct {

	db *DB
	query string
	// stickyErr, when non-nil, is returned by connStmt and Close instead of
	// doing any work.
	stickyErr error

	// closemu is read-locked by ExecContext and QueryContext and
	// write-locked by Close, so Close waits for running executions.
	closemu sync.RWMutex

	// cg is the Tx or Conn this statement is bound to, or nil for a
	// DB-level statement. cgds is the driver statement used on cg's
	// connection.
	cg stmtConnGrabber
	cgds *driverStmt

	// parentStmt is the DB-level Stmt this one was derived from by
	// Tx.StmtContext, if any. In that case the driver statement inside cgds
	// is shared with parentStmt.css, so Close must not close cgds.
	parentStmt *Stmt

	// mu guards closed, css and lastNumClosed.
	mu sync.Mutex
	closed bool

	// css holds the driver statement prepared on each connection this
	// DB-level Stmt has been used on.
	css []connStmt

	// lastNumClosed is db.numClosed as of the last removeClosedStmtLocked
	// sweep.
	lastNumClosed uint64
}
2647
2648
2649
// ExecContext executes the prepared statement with args and returns its
// Result. db.retry drives the attempts. Each attempt gets a connection and the
// matching driver statement from connStmt, and releases the connection with
// that attempt's error.
func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
	// Keep Close from running while the statement executes.
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var res Result
	err := s.db.retry(func(strategy connReuseStrategy) error {
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			return err
		}

		res, err = resultFromStatement(ctx, dc.ci, ds, args...)
		releaseConn(err)
		return err
	})

	return res, err
}
2668
2669
2670
2671
2672
2673
2674 func (s *Stmt) Exec(args ...any) (Result, error) {
2675 return s.ExecContext(context.Background(), args...)
2676 }
2677
2678 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
2679 ds.Lock()
2680 defer ds.Unlock()
2681
2682 dargs, err := driverArgsConnLocked(ci, ds, args)
2683 if err != nil {
2684 return nil, err
2685 }
2686
2687 resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
2688 if err != nil {
2689 return nil, err
2690 }
2691 return driverResult{ds.Locker, resi}, nil
2692 }
2693
2694
2695
2696
2697
2698 func (s *Stmt) removeClosedStmtLocked() {
2699 t := len(s.css)/2 + 1
2700 if t > 10 {
2701 t = 10
2702 }
2703 dbClosed := s.db.numClosed.Load()
2704 if dbClosed-s.lastNumClosed < uint64(t) {
2705 return
2706 }
2707
2708 s.db.mu.Lock()
2709 for i := 0; i < len(s.css); i++ {
2710 if s.css[i].dc.dbmuClosed {
2711 s.css[i] = s.css[len(s.css)-1]
2712
2713 s.css[len(s.css)-1] = connStmt{}
2714 s.css = s.css[:len(s.css)-1]
2715 i--
2716 }
2717 }
2718 s.db.mu.Unlock()
2719 s.lastNumClosed = dbClosed
2720 }
2721
2722
2723
2724
// connStmt returns a connection, its release func, and the driver statement to
// use on that connection.
//   - For a Stmt bound to a Tx/Conn (cg), it grabs that connection and uses
//     cgds.
//   - Otherwise it takes a pool connection using strategy. It reuses the
//     statement already prepared on that connection, or prepares one and
//     records it in css.
//
// It fails with stickyErr, or with an error if the Stmt is closed. Callers
// release the connection through the returned releaseConn.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	// A statement bound to a Tx or Conn always runs on that holder's
	// connection.

	if s.cg != nil {
		s.mu.Unlock()
		dc, releaseConn, err = s.cg.grabConn(ctx)
		if err != nil {
			return
		}
		return dc, releaseConn, s.cgds, nil
	}

	s.removeClosedStmtLocked()
	s.mu.Unlock()

	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	s.mu.Lock()
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// Not prepared on this connection yet: prepare it now with dc locked.
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}
2775
2776
2777
2778 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
2779 si, err := dc.prepareLocked(ctx, s.cg, s.query)
2780 if err != nil {
2781 return nil, err
2782 }
2783 cs := connStmt{dc, si}
2784 s.mu.Lock()
2785 s.css = append(s.css, cs)
2786 s.mu.Unlock()
2787 return cs.ds, nil
2788 }
2789
2790
2791
// QueryContext executes the prepared query with args and returns the Rows.
// db.retry drives the attempts, and each attempt uses connStmt to get a
// connection and the matching driver statement. The Rows are registered as a
// dependency of s until they release their connection.
func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
	// Keep Close from running while the query is being started.
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var rowsi driver.Rows
	var rows *Rows

	err := s.db.retry(func(strategy connReuseStrategy) error {
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			return err
		}

		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
		if err == nil {
			// Ownership of dc passes to the Rows; it is released through
			// rows.releaseConn when the Rows are closed.
			rows = &Rows{
				dc: dc,
				rowsi: rowsi,
				// releaseConn is assigned below.
			}
			// addDep must happen before initContextClose: the goroutine it
			// starts may close the Rows, which calls removeDep.
			s.db.addDep(s, rows)

			// releaseConn must be set before initContextClose for the same
			// reason: the Rows may be closed concurrently from then on.
			rows.releaseConn = func(err error) {
				releaseConn(err)
				s.db.removeDep(s, rows)
			}
			var txctx context.Context
			if s.cg != nil {
				txctx = s.cg.txCtx()
			}
			rows.initContextClose(ctx, txctx)
			return nil
		}

		releaseConn(err)
		return err
	})

	return rows, err
}
2838
2839
2840
2841
2842
2843
2844 func (s *Stmt) Query(args ...any) (*Rows, error) {
2845 return s.QueryContext(context.Background(), args...)
2846 }
2847
2848 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
2849 ds.Lock()
2850 defer ds.Unlock()
2851 dargs, err := driverArgsConnLocked(ci, ds, args)
2852 if err != nil {
2853 return nil, err
2854 }
2855 return ctxDriverStmtQuery(ctx, ds.si, dargs)
2856 }
2857
2858
2859
2860
2861
2862
2863
2864 func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
2865 rows, err := s.QueryContext(ctx, args...)
2866 if err != nil {
2867 return &Row{err: err}
2868 }
2869 return &Row{rows: rows}
2870 }
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886 func (s *Stmt) QueryRow(args ...any) *Row {
2887 return s.QueryRowContext(context.Background(), args...)
2888 }
2889
2890
// Close closes the statement, first waiting for running executions (closemu).
// It returns stickyErr for a Stmt that never became valid, and nil if the Stmt
// is already closed.
//   - A DB-level Stmt, or one derived from a parent via Tx.StmtContext, is
//     cleaned up through s.db.removeDep. NOTE(review): presumably that
//     triggers finalClose once no dependents remain; confirm in removeDep.
//   - A Stmt prepared directly on a Tx/Conn closes its own cgds.
func (s *Stmt) Close() error {
	s.closemu.Lock()
	defer s.closemu.Unlock()

	if s.stickyErr != nil {
		return s.stickyErr
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	txds := s.cgds
	s.cgds = nil

	s.mu.Unlock()

	if s.cg == nil {
		return s.db.removeDep(s, s)
	}

	if s.parentStmt != nil {
		// The driver statement inside txds is shared with parentStmt.css, so
		// it must not be closed here; only drop the dependency on the parent.
		return s.db.removeDep(s.parentStmt, s)
	}
	return txds.Close()
}
2920
2921 func (s *Stmt) finalClose() error {
2922 s.mu.Lock()
2923 defer s.mu.Unlock()
2924 if s.css != nil {
2925 for _, v := range s.css {
2926 s.db.noteUnusedDriverStatement(v.dc, v.ds)
2927 v.dc.removeOpenStmt(v.ds)
2928 }
2929 s.css = nil
2930 }
2931 return nil
2932 }
2933
2934
2935
// Rows is the result of a query. Next advances through the rows and Scan reads
// the current one. The underlying connection is released when the Rows close.
type Rows struct {
	// dc is the driver connection the rows come from. It is locked around
	// every call into rowsi.
	dc *driverConn
	// releaseConn is called by close with the final error.
	releaseConn func(error)
	rowsi driver.Rows
	// cancel stops the awaitDone goroutine started by initContextClose.
	// It is nil if no goroutine was started.
	cancel func()
	// closeStmt, if set, is a statement prepared just for this query and
	// closed together with the Rows.
	closeStmt *driverStmt

	// contextDone is set by awaitDone to the ctx/txctx error when that
	// context ends before the Rows are closed.
	contextDone atomic.Pointer[error]

	// closemu is read-locked by Next, NextResultSet, Scan, Err, Columns and
	// ColumnTypes, and write-locked by close. lasterr and closed are read and
	// written with closemu held.


	closemu sync.RWMutex
	lasterr error
	closed bool

	// closemuScanHold is true while Scan keeps closemu read-locked because
	// dest contained a *RawBytes pointing into raw. Next, NextResultSet and
	// Close release it.





	closemuScanHold bool

	// hitEOF is set by Next when it closed the Rows and returned false.
	// After that, Err reports lasterr instead of any recorded context error.


	hitEOF bool

	// lastcols holds the values of the current row, filled by rowsi.Next.

	lastcols []driver.Value

	// raw is the buffer backing RawBytes values produced by Scan.
	// Scan truncates it before each call.




	raw []byte
}
2980
2981
2982
2983 func (rs *Rows) lasterrOrErrLocked(err error) error {
2984 if rs.lasterr != nil && rs.lasterr != io.EOF {
2985 return rs.lasterr
2986 }
2987 return err
2988 }
2989
2990
2991
// bypassRowsAwaitDone, when true, makes initContextClose skip starting the
// awaitDone goroutine. NOTE(review): it appears to be a test-only switch.
var bypassRowsAwaitDone = false
2993
// initContextClose starts an awaitDone goroutine that closes the Rows when ctx
// or txctx (may be nil) is done.
// Nothing is started when neither context can be cancelled (Done() == nil),
// or when bypassRowsAwaitDone is set. rs.cancel stops the goroutine; close
// calls it.
func (rs *Rows) initContextClose(ctx, txctx context.Context) {
	if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
		return
	}
	if bypassRowsAwaitDone {
		return
	}
	closectx, cancel := context.WithCancel(ctx)
	rs.cancel = cancel
	go rs.awaitDone(ctx, txctx, closectx)
}
3005
3006
3007
3008
3009
3010
// awaitDone waits for one of three events and then calls close(ctx.Err()):
//   - ctx ends: its error is recorded in contextDone;
//   - txctx ends: its error is recorded in contextDone;
//   - closectx is cancelled by Rows.close: nothing is recorded.
//
// close does nothing if the Rows are already closed.
func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
	// A nil channel never becomes ready, so a nil txctx simply never fires.
	var txctxDone <-chan struct{}
	if txctx != nil {
		txctxDone = txctx.Done()
	}
	select {
	case <-ctx.Done():
		err := ctx.Err()
		rs.contextDone.Store(&err)
	case <-txctxDone:
		err := txctx.Err()
		rs.contextDone.Store(&err)
	case <-closectx.Done():
		// Cancelled via rs.cancel from close; contextDone is left unset so
		// Err is not affected.

	}
	rs.close(ctx.Err())
}
3029
3030
3031
3032
3033
3034
3035
// Next advances to the next row, making it available to Scan. It returns false
// when no row is available:
//   - the Rows are closed, or their context is done;
//   - the driver reported an error;
//   - the current result set is exhausted.
//
// On a driver error, or when no further result set exists, the Rows are
// closed. If another result set follows, they stay open for NextResultSet.
func (rs *Rows) Next() bool {
	// Calling Next means the caller is finished with the previous row's
	// RawBytes, so release any read lock Scan kept for them.
	rs.closemuRUnlockIfHeldByScan()

	// awaitDone recorded a context error: report no more rows.
	if rs.contextDone.Load() != nil {
		return false
	}

	var doClose, ok bool
	withLock(rs.closemu.RLocker(), func() {
		doClose, ok = rs.nextLocked()
	})
	if doClose {
		rs.Close()
	}
	// Next has closed the Rows and reported false. From now on Err reports
	// lasterr rather than any recorded context error.
	if doClose && !ok {
		rs.hitEOF = true
	}
	return ok
}
3058
// nextLocked reads the next row into lastcols. The caller holds closemu for
// reading. ok reports whether a row was read. doClose reports whether the Rows
// should now be closed: on a driver error other than io.EOF, or at io.EOF when
// the driver has no further result set.
func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	// Lock the driver connection around calls into rowsi so a concurrent
	// Tx commit/rollback cannot use the connection at the same time.

	rs.dc.Lock()
	defer rs.dc.Unlock()

	if rs.lastcols == nil {
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}

	rs.lasterr = rs.rowsi.Next(rs.lastcols)
	if rs.lasterr != nil {
		// Any error other than io.EOF closes the Rows.
		if rs.lasterr != io.EOF {
			return true, false
		}
		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
		if !ok {
			return true, false
		}
		// End of the current result set: close only if no further result
		// set follows, so NextResultSet can still advance.


		if !nextResultSet.HasNextResultSet() {
			doClose = true
		}
		return doClose, false
	}
	return false, true
}
3093
3094
3095
3096
3097
3098
3099
3100
3101
// NextResultSet advances to the next result set and reports whether there is
// one. lastcols is cleared, so Next must be called before Scan.
// The Rows are closed in two cases: the driver does not implement
// driver.RowsNextResultSet, or advancing fails (the error is stored in
// lasterr).
func (rs *Rows) NextResultSet() bool {
	// The caller is finished with the previous row's RawBytes, so release
	// any read lock Scan kept for them.

	rs.closemuRUnlockIfHeldByScan()

	var doClose bool
	// Close (which needs the write lock) runs only after the deferred
	// RUnlock below has released closemu.
	defer func() {
		if doClose {
			rs.Close()
		}
	}()
	rs.closemu.RLock()
	defer rs.closemu.RUnlock()

	if rs.closed {
		return false
	}

	rs.lastcols = nil
	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
	if !ok {
		doClose = true
		return false
	}

	// Lock the driver connection around the call into rowsi so a concurrent
	// Tx commit/rollback cannot use the connection at the same time.

	rs.dc.Lock()
	defer rs.dc.Unlock()

	rs.lasterr = nextResultSet.NextResultSet()
	if rs.lasterr != nil {
		doClose = true
		return false
	}
	return true
}
3140
3141
3142
// Err returns the error, if any, encountered during iteration.
// A context error recorded by awaitDone takes precedence, unless Next has
// already closed the Rows and reported false (hitEOF). A stored io.EOF is
// reported as nil.
func (rs *Rows) Err() error {
	// Once Next has reported the end, the caller may already have the final
	// row, so a context error that arrived afterwards is not reported.



	if !rs.hitEOF {
		if errp := rs.contextDone.Load(); errp != nil {
			return *errp
		}
	}

	rs.closemu.RLock()
	defer rs.closemu.RUnlock()
	return rs.lasterrOrErrLocked(nil)
}
3158
3159
3160
3161
3162
3163
3164
3165 func (rs *Rows) rawbuf() []byte {
3166 if rs == nil {
3167
3168 return nil
3169 }
3170 return rs.raw
3171 }
3172
3173
3174
3175 func (rs *Rows) setrawbuf(b []byte) RawBytes {
3176 if rs == nil {
3177
3178 return RawBytes(b)
3179 }
3180 off := len(rs.raw)
3181 rs.raw = b
3182 return RawBytes(rs.raw[off:])
3183 }
3184
// errRowsClosed is returned by Columns, ColumnTypes and Scan on closed Rows,
// unless lasterr holds a more specific error.
var errRowsClosed = errors.New("sql: Rows are closed")
// errNoRows is returned by Columns and ColumnTypes when the Rows have no
// driver rows (rowsi is nil), unless lasterr holds a more specific error.
var errNoRows = errors.New("sql: no Rows available")
3187
3188
3189
3190 func (rs *Rows) Columns() ([]string, error) {
3191 rs.closemu.RLock()
3192 defer rs.closemu.RUnlock()
3193 if rs.closed {
3194 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3195 }
3196 if rs.rowsi == nil {
3197 return nil, rs.lasterrOrErrLocked(errNoRows)
3198 }
3199 rs.dc.Lock()
3200 defer rs.dc.Unlock()
3201
3202 return rs.rowsi.Columns(), nil
3203 }
3204
3205
3206
3207 func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
3208 rs.closemu.RLock()
3209 defer rs.closemu.RUnlock()
3210 if rs.closed {
3211 return nil, rs.lasterrOrErrLocked(errRowsClosed)
3212 }
3213 if rs.rowsi == nil {
3214 return nil, rs.lasterrOrErrLocked(errNoRows)
3215 }
3216 rs.dc.Lock()
3217 defer rs.dc.Unlock()
3218
3219 return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
3220 }
3221
3222
// ColumnType describes one result column, as reported by the driver through
// the optional driver.RowsColumnType* interfaces. Each has* flag records
// whether the driver supplied that property.
type ColumnType struct {
	name string

	hasNullable bool
	hasLength bool
	hasPrecisionScale bool

	nullable bool
	length int64
	databaseType string
	precision int64
	scale int64
	scanType reflect.Type
}
3237
3238
// Name returns the column's name, as listed by the driver's Columns method.
func (ci *ColumnType) Name() string {
	return ci.name
}
3242
3243
3244
3245
3246
3247
// Length returns the column length reported through
// driver.RowsColumnTypeLength. ok is false when the driver did not report one.
func (ci *ColumnType) Length() (length int64, ok bool) {
	return ci.length, ci.hasLength
}
3251
3252
3253
// DecimalSize returns the precision and scale reported through
// driver.RowsColumnTypePrecisionScale. ok is false when the driver did not
// report them.
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
	return ci.precision, ci.scale, ci.hasPrecisionScale
}
3257
3258
3259
3260
// ScanType returns the Go type the driver reports for scanning this column.
// It is reflect.TypeFor[any]() when the driver does not implement
// driver.RowsColumnTypeScanType.
func (ci *ColumnType) ScanType() reflect.Type {
	return ci.scanType
}
3264
3265
3266
// Nullable reports whether the column may be NULL, as reported through
// driver.RowsColumnTypeNullable. ok is false when the driver did not report it.
func (ci *ColumnType) Nullable() (nullable, ok bool) {
	return ci.nullable, ci.hasNullable
}
3270
3271
3272
3273
3274
3275
3276
// DatabaseTypeName returns the database type name reported through
// driver.RowsColumnTypeDatabaseTypeName. It is empty if the driver does not
// implement that interface.
func (ci *ColumnType) DatabaseTypeName() string {
	return ci.databaseType
}
3280
3281 func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
3282 names := rowsi.Columns()
3283
3284 list := make([]*ColumnType, len(names))
3285 for i := range list {
3286 ci := &ColumnType{
3287 name: names[i],
3288 }
3289 list[i] = ci
3290
3291 if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
3292 ci.scanType = prop.ColumnTypeScanType(i)
3293 } else {
3294 ci.scanType = reflect.TypeFor[any]()
3295 }
3296 if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
3297 ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
3298 }
3299 if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
3300 ci.length, ci.hasLength = prop.ColumnTypeLength(i)
3301 }
3302 if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
3303 ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
3304 }
3305 if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
3306 ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
3307 }
3308 }
3309 return list
3310 }
3311
3312
3313
3314
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327
3328
3329
3330
3331
3332
3333
3334
3335
3336
3337
3338
3339
3340
3341
3342
3343
3344
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
3356
3357
3358
3359
3360
3361
3362
3363
3364
3365
3366
3367
3368
3369
3370
3371
// Scan copies the current row's columns into dest (see scanLocked).
// If dest contains a *RawBytes and scanning succeeds, Scan keeps the closemu
// read lock (closemuScanHold), so close cannot run while the RawBytes still
// point into rs.raw. The next Next, NextResultSet or Close releases the lock.
func (rs *Rows) Scan(dest ...any) error {
	if rs.closemuScanHold {
		// The read lock from a previous RawBytes Scan is still held, which
		// means Scan was called twice without Next in between.

		return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
	}

	rs.closemu.RLock()
	rs.raw = rs.raw[:0]
	err := rs.scanLocked(dest...)
	if err == nil && scanArgsContainRawBytes(dest) {
		rs.closemuScanHold = true
	} else {
		rs.closemu.RUnlock()
	}
	return err
}
3389
// scanLocked converts the current row into dest. The caller holds closemu for
// reading. It fails in these cases:
//   - lasterr holds an error other than io.EOF;
//   - the Rows are closed;
//   - no row is loaded (lastcols is nil);
//   - len(dest) does not match the column count.
//
// Conversion errors are wrapped with the column index and name.
func (rs *Rows) scanLocked(dest ...any) error {
	if rs.lasterr != nil && rs.lasterr != io.EOF {
		return rs.lasterr
	}
	if rs.closed {
		return rs.lasterrOrErrLocked(errRowsClosed)
	}

	if rs.lastcols == nil {
		return errors.New("sql: Scan called without calling Next")
	}
	if len(dest) != len(rs.lastcols) {
		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
	}

	for i, sv := range rs.lastcols {
		err := convertAssignRows(dest[i], sv, rs)
		if err != nil {
			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
		}
	}
	return nil
}
3413
3414
3415
// closemuRUnlockIfHeldByScan releases the closemu read lock that Scan kept for
// RawBytes destinations, if one is held.
func (rs *Rows) closemuRUnlockIfHeldByScan() {
	if rs.closemuScanHold {
		rs.closemuScanHold = false
		rs.closemu.RUnlock()
	}
}
3422
3423 func scanArgsContainRawBytes(args []any) bool {
3424 for _, a := range args {
3425 if _, ok := a.(*RawBytes); ok {
3426 return true
3427 }
3428 }
3429 return false
3430 }
3431
3432
3433
// rowsCloseHook returns an optional function that Rows.close calls with the
// Rows and a pointer to the driver's Close error, which lets it inspect or
// replace that error. The default returns nil, meaning no hook.
// NOTE(review): it appears to be overridden only by tests.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
3435
3436
3437
3438
3439
// Close closes the Rows and releases their connection. It can be called more
// than once; calls after the first return nil.
func (rs *Rows) Close() error {
	// Release any read lock Scan kept for RawBytes first: close needs the
	// write lock.

	rs.closemuRUnlockIfHeldByScan()

	return rs.close(nil)
}
3448
// close tears the Rows down under closemu's write lock, in this order:
//   - record err as lasterr if none is set yet;
//   - close the driver rows with dc locked;
//   - stop the awaitDone goroutine;
//   - close closeStmt, if any;
//   - release the connection with the driver's Close error.
//
// It returns that Close error, possibly adjusted by rowsCloseHook, or nil if
// the Rows were already closed.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	if rs.cancel != nil {
		rs.cancel()
	}

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)

	rs.lasterr = rs.lasterrOrErrLocked(err)
	return err
}
3480
3481
// Row is the result of a QueryRow/QueryRowContext call.
type Row struct {
	// err is the error from running the query. Row.Scan and Row.Err return
	// it before anything else.
	err error
	// rows is the underlying result that Row.Scan reads one row from.
	rows *Rows
}
3487
3488
3489
3490
3491
3492
// Scan copies the first row's columns into dest and always closes the
// underlying Rows. It returns, in order of precedence:
//   - r.err, if the query failed;
//   - the Rows' error, if iteration failed;
//   - ErrNoRows, if there is no row;
//   - the Scan error, if scanning failed;
//   - otherwise the error from closing the Rows.
//
// *RawBytes destinations are rejected because the Rows are closed before Scan
// returns.
func (r *Row) Scan(dest ...any) error {
	if r.err != nil {
		return r.err
	}

	// The Rows are closed when this function returns, so any []byte the
	// driver handed out is only valid until then. That is why *RawBytes
	// destinations are refused below; other byte slices are copied by
	// Rows.Scan's conversion.











	defer r.rows.Close()
	if scanArgsContainRawBytes(dest) {
		return errors.New("sql: RawBytes isn't allowed on Row.Scan")
	}

	if !r.rows.Next() {
		if err := r.rows.Err(); err != nil {
			return err
		}
		return ErrNoRows
	}
	err := r.rows.Scan(dest...)
	if err != nil {
		return err
	}

	// Surface any error from completing the query.
	return r.rows.Close()
}
3529
3530
3531
3532
3533
// Err returns the error, if any, recorded when the Row was created. It does
// not run the query, so errors that only Scan would hit are not reported.
func (r *Row) Err() error {
	return r.err
}
3537
3538
// Result summarizes an executed command. In this file it is implemented by
// driverResult, which forwards both methods to the driver's driver.Result.
type Result interface {
	// LastInsertId returns the ID the database reports for the executed
	// command. NOTE(review): availability and meaning depend on the driver
	// and database.



	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows the database reports as
	// affected by the command. NOTE(review): driver-dependent as well.

	RowsAffected() (int64, error)
}
3552
// driverResult adapts a driver.Result to Result. It holds the embedded Locker
// while calling into the driver. In this file that Locker is the driver
// connection, or the Locker of the driver statement that produced the result.
type driverResult struct {
	sync.Locker
	resi driver.Result
}
3557
3558 func (dr driverResult) LastInsertId() (int64, error) {
3559 dr.Lock()
3560 defer dr.Unlock()
3561 return dr.resi.LastInsertId()
3562 }
3563
3564 func (dr driverResult) RowsAffected() (int64, error) {
3565 dr.Lock()
3566 defer dr.Unlock()
3567 return dr.resi.RowsAffected()
3568 }
3569
// stack returns the current goroutine's stack trace, truncated to 2 KiB
// (2 << 10 bytes).
func stack() string {
	var buf [2 << 10]byte
	n := runtime.Stack(buf[:], false)
	return string(buf[:n])
}
3574
3575
// withLock runs body while holding locker. The deferred Unlock releases the
// lock even if body panics.
func withLock(locker sync.Locker, body func()) {
	locker.Lock()
	defer locker.Unlock()
	body()
}
3581
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
// connRequestSet is an unordered set of pending connection-request channels.
// Add appends; Delete (by handle) and TakeRandom remove an element by swapping
// in the last one, so removal does not shift the slice. The set has no
// internal locking.
type connRequestSet struct {
	// s holds the entries in no particular order. Each entry's curIdx points
	// at its current position in s.
	s []connRequestAndIndex
}
3596
// connRequestAndIndex is one entry of connRequestSet.
type connRequestAndIndex struct {
	// req is the request channel stored in the set.
	req chan connRequest

	// curIdx points at this entry's current index in connRequestSet.s.
	// The same pointer is shared with the entry's connRequestDelHandle.
	// It is updated when the entry moves and set to -1 once the entry is
	// removed.
	curIdx *int
}
3605
3606
3607
3608 func (s *connRequestSet) CloseAndRemoveAll() {
3609 for _, v := range s.s {
3610 *v.curIdx = -1
3611 close(v.req)
3612 }
3613 s.s = nil
3614 }
3615
3616
// Len returns the number of requests currently in the set.
func (s *connRequestSet) Len() int { return len(s.s) }
3618
3619
3620
// connRequestDelHandle identifies one entry added to a connRequestSet, for use
// with Delete.
type connRequestDelHandle struct {
	// idx is shared with the entry's curIdx: the entry's current position in
	// the set, or -1 once it has been removed.
	idx *int
}
3624
3625
3626
3627
3628 func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
3629 idx := len(s.s)
3630
3631
3632
3633
3634
3635
3636
3637
3638 idxPtr := &idx
3639 s.s = append(s.s, connRequestAndIndex{v, idxPtr})
3640 return connRequestDelHandle{idxPtr}
3641 }
3642
3643
3644
3645
3646
3647 func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
3648 idx := *h.idx
3649 if idx < 0 {
3650 return false
3651 }
3652 s.deleteIndex(idx)
3653 return true
3654 }
3655
3656 func (s *connRequestSet) deleteIndex(idx int) {
3657
3658 *(s.s[idx].curIdx) = -1
3659
3660
3661 if idx < len(s.s)-1 {
3662 last := s.s[len(s.s)-1]
3663 *last.curIdx = idx
3664 s.s[idx] = last
3665 }
3666
3667 s.s[len(s.s)-1] = connRequestAndIndex{}
3668 s.s = s.s[:len(s.s)-1]
3669 }
3670
3671
3672
3673
3674 func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
3675 if len(s.s) == 0 {
3676 return nil, false
3677 }
3678 pick := rand.IntN(len(s.s))
3679 e := s.s[pick]
3680 s.deleteIndex(pick)
3681 return e.req, true
3682 }
3683
View as plain text