Source file src/database/sql/sql.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package sql provides a generic interface around SQL (or SQL-like)
     6  // databases.
     7  //
     8  // The sql package must be used in conjunction with a database driver.
     9  // See https://golang.org/s/sqldrivers for a list of drivers.
    10  //
    11  // Drivers that do not support context cancellation will not return until
    12  // after the query is completed.
    13  //
    14  // For usage examples, see the wiki page at
    15  // https://golang.org/s/sqlwiki.
    16  package sql
    17  
    18  import (
    19  	"context"
    20  	"database/sql/driver"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"maps"
    25  	"math/rand/v2"
    26  	"reflect"
    27  	"runtime"
    28  	"slices"
    29  	"strconv"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  	_ "unsafe"
    34  )
    35  
// driversMu guards the drivers map: Register and unregisterAllDrivers take
// the write lock, while Drivers and Open take the read lock.
var driversMu sync.RWMutex

// drivers should be an internal detail,
// but widely used packages access it using linkname.
// (It is extra wrong that they linkname drivers but not driversMu.)
// Notable members of the hall of shame include:
//   - github.com/instana/go-sensor
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname drivers
var drivers = make(map[string]driver.Driver)

// nowFunc returns the current time; it's overridden in tests.
// Code in this file that needs the current time (for example
// driverConn.expired) calls nowFunc instead of time.Now directly.
var nowFunc = time.Now
    52  
    53  // Register makes a database driver available by the provided name.
    54  // If Register is called twice with the same name or if driver is nil,
    55  // it panics.
    56  func Register(name string, driver driver.Driver) {
    57  	driversMu.Lock()
    58  	defer driversMu.Unlock()
    59  	if driver == nil {
    60  		panic("sql: Register driver is nil")
    61  	}
    62  	if _, dup := drivers[name]; dup {
    63  		panic("sql: Register called twice for driver " + name)
    64  	}
    65  	drivers[name] = driver
    66  }
    67  
    68  func unregisterAllDrivers() {
    69  	driversMu.Lock()
    70  	defer driversMu.Unlock()
    71  	// For tests.
    72  	drivers = make(map[string]driver.Driver)
    73  }
    74  
    75  // Drivers returns a sorted list of the names of the registered drivers.
    76  func Drivers() []string {
    77  	driversMu.RLock()
    78  	defer driversMu.RUnlock()
    79  	return slices.Sorted(maps.Keys(drivers))
    80  }
    81  
// A NamedArg is a named argument. NamedArg values may be used as
// arguments to [DB.Query] or [DB.Exec] and bind to the corresponding named
// parameter in the SQL statement.
//
// For a more concise way to create NamedArg values, see
// the [Named] function.
type NamedArg struct {
	// _NamedFieldsRequired is an unexported, zero-size field. Because it is
	// unexported, code outside this package cannot write an unkeyed
	// NamedArg literal and must name the fields it sets.
	_NamedFieldsRequired struct{}

	// Name is the name of the parameter placeholder.
	//
	// If empty, the ordinal position in the argument list will be
	// used.
	//
	// Name must omit any symbol prefix.
	Name string

	// Value is the value of the parameter.
	// It may be assigned the same value types as the query
	// arguments.
	Value any
}
   104  
   105  // Named provides a more concise way to create [NamedArg] values.
   106  //
   107  // Example usage:
   108  //
   109  //	db.ExecContext(ctx, `
   110  //	    delete from Invoice
   111  //	    where
   112  //	        TimeCreated < @end
   113  //	        and TimeCreated >= @start;`,
   114  //	    sql.Named("start", startTime),
   115  //	    sql.Named("end", endTime),
   116  //	)
   117  func Named(name string, value any) NamedArg {
   118  	// This method exists because the go1compat promise
   119  	// doesn't guarantee that structs don't grow more fields,
   120  	// so unkeyed struct literals are a vet error. Thus, we don't
   121  	// want to allow sql.NamedArg{name, value}.
   122  	return NamedArg{Name: name, Value: value}
   123  }
   124  
   125  // IsolationLevel is the transaction isolation level used in [TxOptions].
   126  type IsolationLevel int
   127  
   128  // Various isolation levels that drivers may support in [DB.BeginTx].
   129  // If a driver does not support a given isolation level an error may be returned.
   130  //
   131  // See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
   132  const (
   133  	LevelDefault IsolationLevel = iota
   134  	LevelReadUncommitted
   135  	LevelReadCommitted
   136  	LevelWriteCommitted
   137  	LevelRepeatableRead
   138  	LevelSnapshot
   139  	LevelSerializable
   140  	LevelLinearizable
   141  )
   142  
   143  // String returns the name of the transaction isolation level.
   144  func (i IsolationLevel) String() string {
   145  	switch i {
   146  	case LevelDefault:
   147  		return "Default"
   148  	case LevelReadUncommitted:
   149  		return "Read Uncommitted"
   150  	case LevelReadCommitted:
   151  		return "Read Committed"
   152  	case LevelWriteCommitted:
   153  		return "Write Committed"
   154  	case LevelRepeatableRead:
   155  		return "Repeatable Read"
   156  	case LevelSnapshot:
   157  		return "Snapshot"
   158  	case LevelSerializable:
   159  		return "Serializable"
   160  	case LevelLinearizable:
   161  		return "Linearizable"
   162  	default:
   163  		return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
   164  	}
   165  }
   166  
   167  var _ fmt.Stringer = LevelDefault
   168  
// TxOptions holds the transaction options to be used in [DB.BeginTx].
type TxOptions struct {
	// Isolation is the transaction isolation level.
	// If zero, the driver or database's default level is used.
	Isolation IsolationLevel
	ReadOnly  bool // NOTE(review): presumably requests a read-only transaction; its handling is not visible here — confirm
}
   176  
// RawBytes is a byte slice that holds a reference to memory owned by
// the database itself. After a [Rows.Scan] into a RawBytes, the slice is only
// valid until the next call to [Rows.Next], [Rows.Scan], or [Rows.Close].
// Copy the bytes if they are needed beyond that point.
type RawBytes []byte
   181  
   182  // NullString represents a string that may be null.
   183  // NullString implements the [Scanner] interface so
   184  // it can be used as a scan destination:
   185  //
   186  //	var s NullString
   187  //	err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
   188  //	...
   189  //	if s.Valid {
   190  //	   // use s.String
   191  //	} else {
   192  //	   // NULL value
   193  //	}
   194  type NullString struct {
   195  	String string
   196  	Valid  bool // Valid is true if String is not NULL
   197  }
   198  
   199  // Scan implements the [Scanner] interface.
   200  func (ns *NullString) Scan(value any) error {
   201  	if value == nil {
   202  		ns.String, ns.Valid = "", false
   203  		return nil
   204  	}
   205  	ns.Valid = true
   206  	return convertAssign(&ns.String, value)
   207  }
   208  
   209  // Value implements the [driver.Valuer] interface.
   210  func (ns NullString) Value() (driver.Value, error) {
   211  	if !ns.Valid {
   212  		return nil, nil
   213  	}
   214  	return ns.String, nil
   215  }
   216  
   217  // NullInt64 represents an int64 that may be null.
   218  // NullInt64 implements the [Scanner] interface so
   219  // it can be used as a scan destination, similar to [NullString].
   220  type NullInt64 struct {
   221  	Int64 int64
   222  	Valid bool // Valid is true if Int64 is not NULL
   223  }
   224  
   225  // Scan implements the [Scanner] interface.
   226  func (n *NullInt64) Scan(value any) error {
   227  	if value == nil {
   228  		n.Int64, n.Valid = 0, false
   229  		return nil
   230  	}
   231  	n.Valid = true
   232  	return convertAssign(&n.Int64, value)
   233  }
   234  
   235  // Value implements the [driver.Valuer] interface.
   236  func (n NullInt64) Value() (driver.Value, error) {
   237  	if !n.Valid {
   238  		return nil, nil
   239  	}
   240  	return n.Int64, nil
   241  }
   242  
   243  // NullInt32 represents an int32 that may be null.
   244  // NullInt32 implements the [Scanner] interface so
   245  // it can be used as a scan destination, similar to [NullString].
   246  type NullInt32 struct {
   247  	Int32 int32
   248  	Valid bool // Valid is true if Int32 is not NULL
   249  }
   250  
   251  // Scan implements the [Scanner] interface.
   252  func (n *NullInt32) Scan(value any) error {
   253  	if value == nil {
   254  		n.Int32, n.Valid = 0, false
   255  		return nil
   256  	}
   257  	n.Valid = true
   258  	return convertAssign(&n.Int32, value)
   259  }
   260  
   261  // Value implements the [driver.Valuer] interface.
   262  func (n NullInt32) Value() (driver.Value, error) {
   263  	if !n.Valid {
   264  		return nil, nil
   265  	}
   266  	return int64(n.Int32), nil
   267  }
   268  
   269  // NullInt16 represents an int16 that may be null.
   270  // NullInt16 implements the [Scanner] interface so
   271  // it can be used as a scan destination, similar to [NullString].
   272  type NullInt16 struct {
   273  	Int16 int16
   274  	Valid bool // Valid is true if Int16 is not NULL
   275  }
   276  
   277  // Scan implements the [Scanner] interface.
   278  func (n *NullInt16) Scan(value any) error {
   279  	if value == nil {
   280  		n.Int16, n.Valid = 0, false
   281  		return nil
   282  	}
   283  	err := convertAssign(&n.Int16, value)
   284  	n.Valid = err == nil
   285  	return err
   286  }
   287  
   288  // Value implements the [driver.Valuer] interface.
   289  func (n NullInt16) Value() (driver.Value, error) {
   290  	if !n.Valid {
   291  		return nil, nil
   292  	}
   293  	return int64(n.Int16), nil
   294  }
   295  
   296  // NullByte represents a byte that may be null.
   297  // NullByte implements the [Scanner] interface so
   298  // it can be used as a scan destination, similar to [NullString].
   299  type NullByte struct {
   300  	Byte  byte
   301  	Valid bool // Valid is true if Byte is not NULL
   302  }
   303  
   304  // Scan implements the [Scanner] interface.
   305  func (n *NullByte) Scan(value any) error {
   306  	if value == nil {
   307  		n.Byte, n.Valid = 0, false
   308  		return nil
   309  	}
   310  	err := convertAssign(&n.Byte, value)
   311  	n.Valid = err == nil
   312  	return err
   313  }
   314  
   315  // Value implements the [driver.Valuer] interface.
   316  func (n NullByte) Value() (driver.Value, error) {
   317  	if !n.Valid {
   318  		return nil, nil
   319  	}
   320  	return int64(n.Byte), nil
   321  }
   322  
   323  // NullFloat64 represents a float64 that may be null.
   324  // NullFloat64 implements the [Scanner] interface so
   325  // it can be used as a scan destination, similar to [NullString].
   326  type NullFloat64 struct {
   327  	Float64 float64
   328  	Valid   bool // Valid is true if Float64 is not NULL
   329  }
   330  
   331  // Scan implements the [Scanner] interface.
   332  func (n *NullFloat64) Scan(value any) error {
   333  	if value == nil {
   334  		n.Float64, n.Valid = 0, false
   335  		return nil
   336  	}
   337  	n.Valid = true
   338  	return convertAssign(&n.Float64, value)
   339  }
   340  
   341  // Value implements the [driver.Valuer] interface.
   342  func (n NullFloat64) Value() (driver.Value, error) {
   343  	if !n.Valid {
   344  		return nil, nil
   345  	}
   346  	return n.Float64, nil
   347  }
   348  
   349  // NullBool represents a bool that may be null.
   350  // NullBool implements the [Scanner] interface so
   351  // it can be used as a scan destination, similar to [NullString].
   352  type NullBool struct {
   353  	Bool  bool
   354  	Valid bool // Valid is true if Bool is not NULL
   355  }
   356  
   357  // Scan implements the [Scanner] interface.
   358  func (n *NullBool) Scan(value any) error {
   359  	if value == nil {
   360  		n.Bool, n.Valid = false, false
   361  		return nil
   362  	}
   363  	n.Valid = true
   364  	return convertAssign(&n.Bool, value)
   365  }
   366  
   367  // Value implements the [driver.Valuer] interface.
   368  func (n NullBool) Value() (driver.Value, error) {
   369  	if !n.Valid {
   370  		return nil, nil
   371  	}
   372  	return n.Bool, nil
   373  }
   374  
   375  // NullTime represents a [time.Time] that may be null.
   376  // NullTime implements the [Scanner] interface so
   377  // it can be used as a scan destination, similar to [NullString].
   378  type NullTime struct {
   379  	Time  time.Time
   380  	Valid bool // Valid is true if Time is not NULL
   381  }
   382  
   383  // Scan implements the [Scanner] interface.
   384  func (n *NullTime) Scan(value any) error {
   385  	if value == nil {
   386  		n.Time, n.Valid = time.Time{}, false
   387  		return nil
   388  	}
   389  	n.Valid = true
   390  	return convertAssign(&n.Time, value)
   391  }
   392  
   393  // Value implements the [driver.Valuer] interface.
   394  func (n NullTime) Value() (driver.Value, error) {
   395  	if !n.Valid {
   396  		return nil, nil
   397  	}
   398  	return n.Time, nil
   399  }
   400  
   401  // Null represents a value that may be null.
   402  // Null implements the [Scanner] interface so
   403  // it can be used as a scan destination:
   404  //
   405  //	var s Null[string]
   406  //	err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
   407  //	...
   408  //	if s.Valid {
   409  //	   // use s.V
   410  //	} else {
   411  //	   // NULL value
   412  //	}
   413  //
   414  // T should be one of the types accepted by [driver.Value].
   415  type Null[T any] struct {
   416  	V     T
   417  	Valid bool
   418  }
   419  
   420  func (n *Null[T]) Scan(value any) error {
   421  	if value == nil {
   422  		n.V, n.Valid = *new(T), false
   423  		return nil
   424  	}
   425  	n.Valid = true
   426  	return convertAssign(&n.V, value)
   427  }
   428  
   429  func (n Null[T]) Value() (driver.Value, error) {
   430  	if !n.Valid {
   431  		return nil, nil
   432  	}
   433  	v := any(n.V)
   434  	// See issue 69728.
   435  	if valuer, ok := v.(driver.Valuer); ok {
   436  		val, err := callValuerValue(valuer)
   437  		if err != nil {
   438  			return val, err
   439  		}
   440  		v = val
   441  	}
   442  	// See issue 69837.
   443  	return driver.DefaultParameterConverter.ConvertValue(v)
   444  }
   445  
// Scanner is an interface used by [Rows.Scan].
//
// The nullable wrapper types in this package, such as [NullString] and
// [Null], implement Scanner so that they can be used as scan destinations.
type Scanner interface {
	// Scan assigns a value from a database driver.
	//
	// The src value will be of one of the following types:
	//
	//    int64
	//    float64
	//    bool
	//    []byte
	//    string
	//    time.Time
	//    nil - for NULL values
	//
	// An error should be returned if the value cannot be stored
	// without loss of information.
	//
	// Reference types such as []byte are only valid until the next call to Scan
	// and should not be retained. Their underlying memory is owned by the driver.
	// If retention is necessary, copy their values before the next call to Scan.
	Scan(src any) error
}
   468  
// Out may be used to retrieve OUTPUT value parameters from stored procedures.
//
// Not all drivers and databases support OUTPUT value parameters.
//
// Example usage:
//
//	var outArg string
//	_, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
type Out struct {
	// _NamedFieldsRequired is an unexported, zero-size field. Because it is
	// unexported, code outside this package cannot write an unkeyed Out
	// literal and must name the fields it sets.
	_NamedFieldsRequired struct{}

	// Dest is a pointer to the value that will be set to the result of the
	// stored procedure's OUTPUT parameter.
	Dest any

	// In is whether the parameter is an INOUT parameter. If so, the input value to the stored
	// procedure is the dereferenced value of Dest's pointer, which is then replaced with
	// the output value.
	In bool
}
   489  
// ErrNoRows is returned by [Row.Scan] when [DB.QueryRow] doesn't return a
// row. In such a case, QueryRow returns a placeholder [*Row] value that
// defers this error until a Scan; QueryRow itself does not report it.
var ErrNoRows = errors.New("sql: no rows in result set")
   494  
// DB is a database handle representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
//
// The sql package creates and frees connections automatically; it
// also maintains a free pool of idle connections. If the database has
// a concept of per-connection state, such state can be reliably observed
// within a transaction ([Tx]) or connection ([Conn]). Once [DB.Begin] is called, the
// returned [Tx] is bound to a single connection. Once [Tx.Commit] or
// [Tx.Rollback] is called on the transaction, that transaction's
// connection is returned to [DB]'s idle connection pool. The pool size
// can be controlled with [DB.SetMaxIdleConns].
type DB struct {
	// Total time waited for new connections.
	waitDuration atomic.Int64

	connector driver.Connector // the connector passed to OpenDB
	// numClosed is an atomic counter which represents a total number of
	// closed connections. Stmt.openStmt checks it before cleaning closed
	// connections in Stmt.css.
	numClosed atomic.Uint64

	mu           sync.Mutex    // protects following fields
	freeConn     []*driverConn // free connections ordered by returnedAt oldest to newest
	connRequests connRequestSet
	numOpen      int // number of opened and pending open connections
	// Used to signal the need for new connections
	// a goroutine running connectionOpener() reads on this chan and
	// maybeOpenNewConnections sends on the chan (one send per needed connection)
	// It is closed during db.Close(). The close tells the connectionOpener
	// goroutine to exit.
	openerCh          chan struct{}
	closed            bool
	dep               map[finalCloser]depSet // outstanding dependencies per finalCloser; see addDepLocked and removeDepLocked
	lastPut           map[*driverConn]string // stacktrace of last conn's put; debug only
	maxIdleCount      int                    // zero means defaultMaxIdleConns; negative means 0
	maxOpen           int                    // <= 0 means unlimited
	maxLifetime       time.Duration          // maximum amount of time a connection may be reused
	maxIdleTime       time.Duration          // maximum amount of time a connection may be idle before being closed
	cleanerCh         chan struct{}
	waitCount         int64 // Total number of connections waited for.
	maxIdleClosed     int64 // Total number of connections closed due to idle count.
	maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
	maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.

	stop func() // stop cancels the connection opener.
}
   542  
// connReuseStrategy determines how (*DB).conn returns database connections.
// PingContext, for example, receives a strategy from db.retry and forwards
// it unchanged to db.conn.
type connReuseStrategy uint8

const (
	// alwaysNewConn forces a new connection to the database.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn returns a cached connection, if available, else waits
	// for one to become available (if MaxOpenConns has been reached) or
	// creates a new database connection.
	cachedOrNewConn
)
   554  
// driverConn wraps a driver.Conn with a mutex, to
// be held during all calls into the Conn. (including any calls onto
// interfaces returned via that Conn, such as calls on Tx, Stmt,
// Result, Rows)
//
// Its fields fall into three groups: db and createdAt, which carry no
// guard annotation; the fields guarded by the embedded mutex; and the
// fields guarded by db.mu.
type driverConn struct {
	db        *DB
	createdAt time.Time // expired compares createdAt plus a timeout against nowFunc()

	sync.Mutex  // guards following
	ci          driver.Conn
	needReset   bool // The connection session should be reset before use if true.
	closed      bool
	finalClosed bool // ci.Close has been called
	openStmt    map[*driverStmt]bool

	// guarded by db.mu
	inUse      bool
	dbmuClosed bool      // same as closed, but guarded by db.mu, for removeClosedStmtLocked
	returnedAt time.Time // Time the connection was created or returned.
	onPut      []func()  // code (with db.mu held) run when conn is next returned
}
   576  
   577  func (dc *driverConn) releaseConn(err error) {
   578  	dc.db.putConn(dc, err, true)
   579  }
   580  
   581  func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
   582  	dc.Lock()
   583  	defer dc.Unlock()
   584  	delete(dc.openStmt, ds)
   585  }
   586  
   587  func (dc *driverConn) expired(timeout time.Duration) bool {
   588  	if timeout <= 0 {
   589  		return false
   590  	}
   591  	return dc.createdAt.Add(timeout).Before(nowFunc())
   592  }
   593  
   594  // resetSession checks if the driver connection needs the
   595  // session to be reset and if required, resets it.
   596  func (dc *driverConn) resetSession(ctx context.Context) error {
   597  	dc.Lock()
   598  	defer dc.Unlock()
   599  
   600  	if !dc.needReset {
   601  		return nil
   602  	}
   603  	if cr, ok := dc.ci.(driver.SessionResetter); ok {
   604  		return cr.ResetSession(ctx)
   605  	}
   606  	return nil
   607  }
   608  
   609  // validateConnection checks if the connection is valid and can
   610  // still be used. It also marks the session for reset if required.
   611  func (dc *driverConn) validateConnection(needsReset bool) bool {
   612  	dc.Lock()
   613  	defer dc.Unlock()
   614  
   615  	if needsReset {
   616  		dc.needReset = true
   617  	}
   618  	if cv, ok := dc.ci.(driver.Validator); ok {
   619  		return cv.IsValid()
   620  	}
   621  	return true
   622  }
   623  
   624  // prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
   625  // the prepared statements in a pool.
   626  func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
   627  	si, err := ctxDriverPrepare(ctx, dc.ci, query)
   628  	if err != nil {
   629  		return nil, err
   630  	}
   631  	ds := &driverStmt{Locker: dc, si: si}
   632  
   633  	// No need to manage open statements if there is a single connection grabber.
   634  	if cg != nil {
   635  		return ds, nil
   636  	}
   637  
   638  	// Track each driverConn's open statements, so we can close them
   639  	// before closing the conn.
   640  	//
   641  	// Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
   642  	if dc.openStmt == nil {
   643  		dc.openStmt = make(map[*driverStmt]bool)
   644  	}
   645  	dc.openStmt[ds] = true
   646  	return ds, nil
   647  }
   648  
   649  // the dc.db's Mutex is held.
   650  func (dc *driverConn) closeDBLocked() func() error {
   651  	dc.Lock()
   652  	defer dc.Unlock()
   653  	if dc.closed {
   654  		return func() error { return errors.New("sql: duplicate driverConn close") }
   655  	}
   656  	dc.closed = true
   657  	return dc.db.removeDepLocked(dc, dc)
   658  }
   659  
   660  func (dc *driverConn) Close() error {
   661  	dc.Lock()
   662  	if dc.closed {
   663  		dc.Unlock()
   664  		return errors.New("sql: duplicate driverConn close")
   665  	}
   666  	dc.closed = true
   667  	dc.Unlock() // not defer; removeDep finalClose calls may need to lock
   668  
   669  	// And now updates that require holding dc.mu.Lock.
   670  	dc.db.mu.Lock()
   671  	dc.dbmuClosed = true
   672  	fn := dc.db.removeDepLocked(dc, dc)
   673  	dc.db.mu.Unlock()
   674  	return fn()
   675  }
   676  
   677  func (dc *driverConn) finalClose() error {
   678  	var err error
   679  
   680  	// Each *driverStmt has a lock to the dc. Copy the list out of the dc
   681  	// before calling close on each stmt.
   682  	var openStmt []*driverStmt
   683  	withLock(dc, func() {
   684  		openStmt = make([]*driverStmt, 0, len(dc.openStmt))
   685  		for ds := range dc.openStmt {
   686  			openStmt = append(openStmt, ds)
   687  		}
   688  		dc.openStmt = nil
   689  	})
   690  	for _, ds := range openStmt {
   691  		ds.Close()
   692  	}
   693  	withLock(dc, func() {
   694  		dc.finalClosed = true
   695  		err = dc.ci.Close()
   696  		dc.ci = nil
   697  	})
   698  
   699  	dc.db.mu.Lock()
   700  	dc.db.numOpen--
   701  	dc.db.maybeOpenNewConnections()
   702  	dc.db.mu.Unlock()
   703  
   704  	dc.db.numClosed.Add(1)
   705  	return err
   706  }
   707  
   708  // driverStmt associates a driver.Stmt with the
   709  // *driverConn from which it came, so the driverConn's lock can be
   710  // held during calls.
   711  type driverStmt struct {
   712  	sync.Locker // the *driverConn
   713  	si          driver.Stmt
   714  	closed      bool
   715  	closeErr    error // return value of previous Close call
   716  }
   717  
   718  // Close ensures driver.Stmt is only closed once and always returns the same
   719  // result.
   720  func (ds *driverStmt) Close() error {
   721  	ds.Lock()
   722  	defer ds.Unlock()
   723  	if ds.closed {
   724  		return ds.closeErr
   725  	}
   726  	ds.closed = true
   727  	ds.closeErr = ds.si.Close()
   728  	return ds.closeErr
   729  }
   730  
// depSet is the set of a finalCloser's outstanding dependencies.
// addDepLocked only ever stores the value true, so membership is what
// matters, not the bool.
type depSet map[any]bool // set of true bools
   733  
// The finalCloser interface is used by (*DB).addDep and related
// dependency reference counting. In this file, *driverConn implements it.
type finalCloser interface {
	// finalClose is called when the reference count of an object
	// goes to zero. (*DB).mu is not held while calling it.
	finalClose() error
}
   741  
   742  // addDep notes that x now depends on dep, and x's finalClose won't be
   743  // called until all of x's dependencies are removed with removeDep.
   744  func (db *DB) addDep(x finalCloser, dep any) {
   745  	db.mu.Lock()
   746  	defer db.mu.Unlock()
   747  	db.addDepLocked(x, dep)
   748  }
   749  
   750  func (db *DB) addDepLocked(x finalCloser, dep any) {
   751  	if db.dep == nil {
   752  		db.dep = make(map[finalCloser]depSet)
   753  	}
   754  	xdep := db.dep[x]
   755  	if xdep == nil {
   756  		xdep = make(depSet)
   757  		db.dep[x] = xdep
   758  	}
   759  	xdep[dep] = true
   760  }
   761  
   762  // removeDep notes that x no longer depends on dep.
   763  // If x still has dependencies, nil is returned.
   764  // If x no longer has any dependencies, its finalClose method will be
   765  // called and its error value will be returned.
   766  func (db *DB) removeDep(x finalCloser, dep any) error {
   767  	db.mu.Lock()
   768  	fn := db.removeDepLocked(x, dep)
   769  	db.mu.Unlock()
   770  	return fn()
   771  }
   772  
   773  func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
   774  	xdep, ok := db.dep[x]
   775  	if !ok {
   776  		panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
   777  	}
   778  
   779  	l0 := len(xdep)
   780  	delete(xdep, dep)
   781  
   782  	switch len(xdep) {
   783  	case l0:
   784  		// Nothing removed. Shouldn't happen.
   785  		panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
   786  	case 0:
   787  		// No more dependencies.
   788  		delete(db.dep, x)
   789  		return x.finalClose
   790  	default:
   791  		// Dependencies remain.
   792  		return func() error { return nil }
   793  	}
   794  }
   795  
// This is the size of the connectionOpener request chan (DB.openerCh).
// This value should be larger than the maximum typical value
// used for DB.maxOpen. If maxOpen is significantly larger than
// connectionRequestQueueSize then it is possible for ALL calls into the *DB
// to block until the connectionOpener can satisfy the backlog of requests.
//
// OpenDB uses this value as the buffer capacity when it creates openerCh.
var connectionRequestQueueSize = 1000000
   802  
   803  type dsnConnector struct {
   804  	dsn    string
   805  	driver driver.Driver
   806  }
   807  
   808  func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
   809  	return t.driver.Open(t.dsn)
   810  }
   811  
   812  func (t dsnConnector) Driver() driver.Driver {
   813  	return t.driver
   814  }
   815  
   816  // OpenDB opens a database using a [driver.Connector], allowing drivers to
   817  // bypass a string based data source name.
   818  //
   819  // Most users will open a database via a driver-specific connection
   820  // helper function that returns a [*DB]. No database drivers are included
   821  // in the Go standard library. See https://golang.org/s/sqldrivers for
   822  // a list of third-party drivers.
   823  //
   824  // OpenDB may just validate its arguments without creating a connection
   825  // to the database. To verify that the data source name is valid, call
   826  // [DB.Ping].
   827  //
   828  // The returned [DB] is safe for concurrent use by multiple goroutines
   829  // and maintains its own pool of idle connections. Thus, the OpenDB
   830  // function should be called just once. It is rarely necessary to
   831  // close a [DB].
   832  func OpenDB(c driver.Connector) *DB {
   833  	ctx, cancel := context.WithCancel(context.Background())
   834  	db := &DB{
   835  		connector: c,
   836  		openerCh:  make(chan struct{}, connectionRequestQueueSize),
   837  		lastPut:   make(map[*driverConn]string),
   838  		stop:      cancel,
   839  	}
   840  
   841  	go db.connectionOpener(ctx)
   842  
   843  	return db
   844  }
   845  
   846  // Open opens a database specified by its database driver name and a
   847  // driver-specific data source name, usually consisting of at least a
   848  // database name and connection information.
   849  //
   850  // Most users will open a database via a driver-specific connection
   851  // helper function that returns a [*DB]. No database drivers are included
   852  // in the Go standard library. See https://golang.org/s/sqldrivers for
   853  // a list of third-party drivers.
   854  //
   855  // Open may just validate its arguments without creating a connection
   856  // to the database. To verify that the data source name is valid, call
   857  // [DB.Ping].
   858  //
   859  // The returned [DB] is safe for concurrent use by multiple goroutines
   860  // and maintains its own pool of idle connections. Thus, the Open
   861  // function should be called just once. It is rarely necessary to
   862  // close a [DB].
   863  func Open(driverName, dataSourceName string) (*DB, error) {
   864  	driversMu.RLock()
   865  	driveri, ok := drivers[driverName]
   866  	driversMu.RUnlock()
   867  	if !ok {
   868  		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
   869  	}
   870  
   871  	if driverCtx, ok := driveri.(driver.DriverContext); ok {
   872  		connector, err := driverCtx.OpenConnector(dataSourceName)
   873  		if err != nil {
   874  			return nil, err
   875  		}
   876  		return OpenDB(connector), nil
   877  	}
   878  
   879  	return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
   880  }
   881  
   882  func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
   883  	var err error
   884  	if pinger, ok := dc.ci.(driver.Pinger); ok {
   885  		withLock(dc, func() {
   886  			err = pinger.Ping(ctx)
   887  		})
   888  	}
   889  	release(err)
   890  	return err
   891  }
   892  
   893  // PingContext verifies a connection to the database is still alive,
   894  // establishing a connection if necessary.
   895  func (db *DB) PingContext(ctx context.Context) error {
   896  	var dc *driverConn
   897  	var err error
   898  
   899  	err = db.retry(func(strategy connReuseStrategy) error {
   900  		dc, err = db.conn(ctx, strategy)
   901  		return err
   902  	})
   903  
   904  	if err != nil {
   905  		return err
   906  	}
   907  
   908  	return db.pingDC(ctx, dc, dc.releaseConn)
   909  }
   910  
   911  // Ping verifies a connection to the database is still alive,
   912  // establishing a connection if necessary.
   913  //
   914  // Ping uses [context.Background] internally; to specify the context, use
   915  // [DB.PingContext].
   916  func (db *DB) Ping() error {
   917  	return db.PingContext(context.Background())
   918  }
   919  
   920  // Close closes the database and prevents new queries from starting.
   921  // Close then waits for all queries that have started processing on the server
   922  // to finish.
   923  //
   924  // It is rare to Close a [DB], as the [DB] handle is meant to be
   925  // long-lived and shared between many goroutines.
   926  func (db *DB) Close() error {
   927  	db.mu.Lock()
   928  	if db.closed { // Make DB.Close idempotent
   929  		db.mu.Unlock()
   930  		return nil
   931  	}
   932  	if db.cleanerCh != nil {
   933  		close(db.cleanerCh)
   934  	}
   935  	var err error
   936  	fns := make([]func() error, 0, len(db.freeConn))
   937  	for _, dc := range db.freeConn {
   938  		fns = append(fns, dc.closeDBLocked())
   939  	}
   940  	db.freeConn = nil
   941  	db.closed = true
   942  	db.connRequests.CloseAndRemoveAll()
   943  	db.mu.Unlock()
   944  	for _, fn := range fns {
   945  		err1 := fn()
   946  		if err1 != nil {
   947  			err = err1
   948  		}
   949  	}
   950  	db.stop()
   951  	if c, ok := db.connector.(io.Closer); ok {
   952  		err1 := c.Close()
   953  		if err1 != nil {
   954  			err = err1
   955  		}
   956  	}
   957  	return err
   958  }
   959  
// defaultMaxIdleConns is the idle-connection limit maxIdleConnsLocked
// reports while db.maxIdleCount is still zero, i.e. not explicitly set.
const defaultMaxIdleConns = 2
   961  
   962  func (db *DB) maxIdleConnsLocked() int {
   963  	n := db.maxIdleCount
   964  	switch {
   965  	case n == 0:
   966  		// TODO(bradfitz): ask driver, if supported, for its default preference
   967  		return defaultMaxIdleConns
   968  	case n < 0:
   969  		return 0
   970  	default:
   971  		return n
   972  	}
   973  }
   974  
   975  func (db *DB) shortestIdleTimeLocked() time.Duration {
   976  	if db.maxIdleTime <= 0 {
   977  		return db.maxLifetime
   978  	}
   979  	if db.maxLifetime <= 0 {
   980  		return db.maxIdleTime
   981  	}
   982  	return min(db.maxIdleTime, db.maxLifetime)
   983  }
   984  
   985  // SetMaxIdleConns sets the maximum number of connections in the idle
   986  // connection pool.
   987  //
   988  // If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
   989  // then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
   990  //
   991  // If n <= 0, no idle connections are retained.
   992  //
   993  // The default max idle connections is currently 2. This may change in
   994  // a future release.
   995  func (db *DB) SetMaxIdleConns(n int) {
   996  	db.mu.Lock()
   997  	if n > 0 {
   998  		db.maxIdleCount = n
   999  	} else {
  1000  		// No idle connections.
  1001  		db.maxIdleCount = -1
  1002  	}
  1003  	// Make sure maxIdle doesn't exceed maxOpen
  1004  	if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
  1005  		db.maxIdleCount = db.maxOpen
  1006  	}
  1007  	var closing []*driverConn
  1008  	idleCount := len(db.freeConn)
  1009  	maxIdle := db.maxIdleConnsLocked()
  1010  	if idleCount > maxIdle {
  1011  		closing = db.freeConn[maxIdle:]
  1012  		db.freeConn = db.freeConn[:maxIdle]
  1013  	}
  1014  	db.maxIdleClosed += int64(len(closing))
  1015  	db.mu.Unlock()
  1016  	for _, c := range closing {
  1017  		c.Close()
  1018  	}
  1019  }
  1020  
  1021  // SetMaxOpenConns sets the maximum number of open connections to the database.
  1022  //
  1023  // If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
  1024  // MaxIdleConns, then MaxIdleConns will be reduced to match the new
  1025  // MaxOpenConns limit.
  1026  //
  1027  // If n <= 0, then there is no limit on the number of open connections.
  1028  // The default is 0 (unlimited).
  1029  func (db *DB) SetMaxOpenConns(n int) {
  1030  	db.mu.Lock()
  1031  	db.maxOpen = n
  1032  	if n < 0 {
  1033  		db.maxOpen = 0
  1034  	}
  1035  	syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
  1036  	db.mu.Unlock()
  1037  	if syncMaxIdle {
  1038  		db.SetMaxIdleConns(n)
  1039  	}
  1040  }
  1041  
  1042  // SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
  1043  //
  1044  // Expired connections may be closed lazily before reuse.
  1045  //
  1046  // If d <= 0, connections are not closed due to a connection's age.
  1047  func (db *DB) SetConnMaxLifetime(d time.Duration) {
  1048  	if d < 0 {
  1049  		d = 0
  1050  	}
  1051  	db.mu.Lock()
  1052  	// Wake cleaner up when lifetime is shortened.
  1053  	if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
  1054  		select {
  1055  		case db.cleanerCh <- struct{}{}:
  1056  		default:
  1057  		}
  1058  	}
  1059  	db.maxLifetime = d
  1060  	db.startCleanerLocked()
  1061  	db.mu.Unlock()
  1062  }
  1063  
  1064  // SetConnMaxIdleTime sets the maximum amount of time a connection may be idle.
  1065  //
  1066  // Expired connections may be closed lazily before reuse.
  1067  //
  1068  // If d <= 0, connections are not closed due to a connection's idle time.
  1069  func (db *DB) SetConnMaxIdleTime(d time.Duration) {
  1070  	if d < 0 {
  1071  		d = 0
  1072  	}
  1073  	db.mu.Lock()
  1074  	defer db.mu.Unlock()
  1075  
  1076  	// Wake cleaner up when idle time is shortened.
  1077  	if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil {
  1078  		select {
  1079  		case db.cleanerCh <- struct{}{}:
  1080  		default:
  1081  		}
  1082  	}
  1083  	db.maxIdleTime = d
  1084  	db.startCleanerLocked()
  1085  }
  1086  
  1087  // startCleanerLocked starts connectionCleaner if needed.
  1088  func (db *DB) startCleanerLocked() {
  1089  	if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
  1090  		db.cleanerCh = make(chan struct{}, 1)
  1091  		go db.connectionCleaner(db.shortestIdleTimeLocked())
  1092  	}
  1093  }
  1094  
  1095  func (db *DB) connectionCleaner(d time.Duration) {
  1096  	const minInterval = time.Second
  1097  
  1098  	if d < minInterval {
  1099  		d = minInterval
  1100  	}
  1101  	t := time.NewTimer(d)
  1102  
  1103  	for {
  1104  		select {
  1105  		case <-t.C:
  1106  		case <-db.cleanerCh: // maxLifetime was changed or db was closed.
  1107  		}
  1108  
  1109  		db.mu.Lock()
  1110  
  1111  		d = db.shortestIdleTimeLocked()
  1112  		if db.closed || db.numOpen == 0 || d <= 0 {
  1113  			db.cleanerCh = nil
  1114  			db.mu.Unlock()
  1115  			return
  1116  		}
  1117  
  1118  		d, closing := db.connectionCleanerRunLocked(d)
  1119  		db.mu.Unlock()
  1120  		for _, c := range closing {
  1121  			c.Close()
  1122  		}
  1123  
  1124  		if d < minInterval {
  1125  			d = minInterval
  1126  		}
  1127  
  1128  		if !t.Stop() {
  1129  			select {
  1130  			case <-t.C:
  1131  			default:
  1132  			}
  1133  		}
  1134  		t.Reset(d)
  1135  	}
  1136  }
  1137  
// connectionCleanerRunLocked removes connections that should be closed from
// freeConn and returns them alongside an updated duration to the next check
// if a quicker check is required to ensure connections are checked appropriately.
//
// d is the caller's proposed interval until the next run; the returned
// duration is d, or a smaller value derived from the connections that
// remain in freeConn. The removed connections are not closed here; the
// caller is expected to close them. maxIdleTimeClosed and
// maxLifetimeClosed are updated to account for the removals.
func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
	var idleClosing int64
	var closing []*driverConn
	if db.maxIdleTime > 0 {
		// As freeConn is ordered by returnedAt process
		// in reverse order to minimise the work needed.
		idleSince := nowFunc().Add(-db.maxIdleTime)
		last := len(db.freeConn) - 1
		for i := last; i >= 0; i-- {
			c := db.freeConn[i]
			if c.returnedAt.Before(idleSince) {
				// Given the ordering above, every entry up to and including
				// index i has been idle too long; split the slice after it.
				i++
				// The capacity of closing is capped at i so that the append
				// in the lifetime pass below cannot write into freeConn.
				closing = db.freeConn[:i:i]
				db.freeConn = db.freeConn[i:]
				idleClosing = int64(len(closing))
				db.maxIdleTimeClosed += idleClosing
				break
			}
		}

		if len(db.freeConn) > 0 {
			c := db.freeConn[0]
			// d2 is the time left before the first remaining connection
			// reaches maxIdleTime.
			if d2 := c.returnedAt.Sub(idleSince); d2 < d {
				// Ensure idle connections are cleaned up as soon as
				// possible.
				d = d2
			}
		}
	}

	if db.maxLifetime > 0 {
		expiredSince := nowFunc().Add(-db.maxLifetime)
		for i := 0; i < len(db.freeConn); i++ {
			c := db.freeConn[i]
			if c.createdAt.Before(expiredSince) {
				closing = append(closing, c)

				last := len(db.freeConn) - 1
				// Use slow delete as order is required to ensure
				// connections are reused least idle time first.
				copy(db.freeConn[i:], db.freeConn[i+1:])
				db.freeConn[last] = nil
				db.freeConn = db.freeConn[:last]
				// Re-examine index i, which now holds the next element.
				i--
			} else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
				// Prevent connections sitting the freeConn when they
				// have expired by updating our next deadline d.
				d = d2
			}
		}
		// closing also contains the idle-time removals from the first
		// pass, so subtract them to count only lifetime expirations.
		db.maxLifetimeClosed += int64(len(closing)) - idleClosing
	}

	return d, closing
}
  1196  
// DBStats contains database statistics.
// It is the snapshot type returned by [DB.Stats].
type DBStats struct {
	MaxOpenConnections int // Maximum number of open connections to the database.

	// Pool Status
	OpenConnections int // The number of established connections both in use and idle.
	InUse           int // The number of connections currently in use.
	Idle            int // The number of idle connections.

	// Counters
	WaitCount         int64         // The total number of connections waited for.
	WaitDuration      time.Duration // The total time blocked waiting for a new connection.
	MaxIdleClosed     int64         // The total number of connections closed due to SetMaxIdleConns.
	MaxIdleTimeClosed int64         // The total number of connections closed due to SetConnMaxIdleTime.
	MaxLifetimeClosed int64         // The total number of connections closed due to SetConnMaxLifetime.
}
  1213  
  1214  // Stats returns database statistics.
  1215  func (db *DB) Stats() DBStats {
  1216  	wait := db.waitDuration.Load()
  1217  
  1218  	db.mu.Lock()
  1219  	defer db.mu.Unlock()
  1220  
  1221  	stats := DBStats{
  1222  		MaxOpenConnections: db.maxOpen,
  1223  
  1224  		Idle:            len(db.freeConn),
  1225  		OpenConnections: db.numOpen,
  1226  		InUse:           db.numOpen - len(db.freeConn),
  1227  
  1228  		WaitCount:         db.waitCount,
  1229  		WaitDuration:      time.Duration(wait),
  1230  		MaxIdleClosed:     db.maxIdleClosed,
  1231  		MaxIdleTimeClosed: db.maxIdleTimeClosed,
  1232  		MaxLifetimeClosed: db.maxLifetimeClosed,
  1233  	}
  1234  	return stats
  1235  }
  1236  
  1237  // Assumes db.mu is locked.
  1238  // If there are connRequests and the connection limit hasn't been reached,
  1239  // then tell the connectionOpener to open new connections.
  1240  func (db *DB) maybeOpenNewConnections() {
  1241  	numRequests := db.connRequests.Len()
  1242  	if db.maxOpen > 0 {
  1243  		numCanOpen := db.maxOpen - db.numOpen
  1244  		if numRequests > numCanOpen {
  1245  			numRequests = numCanOpen
  1246  		}
  1247  	}
  1248  	for numRequests > 0 {
  1249  		db.numOpen++ // optimistically
  1250  		numRequests--
  1251  		if db.closed {
  1252  			return
  1253  		}
  1254  		db.openerCh <- struct{}{}
  1255  	}
  1256  }
  1257  
  1258  // Runs in a separate goroutine, opens new connections when requested.
  1259  func (db *DB) connectionOpener(ctx context.Context) {
  1260  	for {
  1261  		select {
  1262  		case <-ctx.Done():
  1263  			return
  1264  		case <-db.openerCh:
  1265  			db.openNewConnection(ctx)
  1266  		}
  1267  	}
  1268  }
  1269  
  1270  // Open one new connection
  1271  func (db *DB) openNewConnection(ctx context.Context) {
  1272  	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
  1273  	// on db.openerCh. This function must execute db.numOpen-- if the
  1274  	// connection fails or is closed before returning.
  1275  	ci, err := db.connector.Connect(ctx)
  1276  	db.mu.Lock()
  1277  	defer db.mu.Unlock()
  1278  	if db.closed {
  1279  		if err == nil {
  1280  			ci.Close()
  1281  		}
  1282  		db.numOpen--
  1283  		return
  1284  	}
  1285  	if err != nil {
  1286  		db.numOpen--
  1287  		db.putConnDBLocked(nil, err)
  1288  		db.maybeOpenNewConnections()
  1289  		return
  1290  	}
  1291  	dc := &driverConn{
  1292  		db:         db,
  1293  		createdAt:  nowFunc(),
  1294  		returnedAt: nowFunc(),
  1295  		ci:         ci,
  1296  	}
  1297  	if db.putConnDBLocked(dc, err) {
  1298  		db.addDepLocked(dc, dc)
  1299  	} else {
  1300  		db.numOpen--
  1301  		ci.Close()
  1302  	}
  1303  }
  1304  
// connRequest represents one request for a new connection.
// When there are no idle connections available, DB.conn will create
// a new connRequest and put it on the db.connRequests list.
type connRequest struct {
	conn *driverConn // connection handed to the waiter; nil when err reports a failed open
	err  error       // error delivered along with (or instead of) conn
}
  1312  
// errDBClosed is returned by DB.conn when the DB is already closed, or
// when a pending connection request's channel is closed while waiting.
var errDBClosed = errors.New("sql: database is closed")
  1314  
// conn returns a newly-opened or cached *driverConn.
//
// With strategy cachedOrNewConn an idle connection from freeConn is
// preferred. If none is used and the maxOpen limit has been reached,
// conn queues a connRequest and blocks until a connection is delivered,
// the DB is closed, or ctx is done. Otherwise it opens a new connection
// on the calling goroutine.
//
// It returns errDBClosed if the DB is closed, ctx.Err() if ctx is done,
// and driver.ErrBadConn if the chosen cached connection has expired; a
// session reset that reports driver.ErrBadConn is returned as is.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	// Snapshot the lifetime under the lock; it is used after unlocking below.
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
		// Reuse the lowest idle time connection so we can close
		// connections which remain idle as soon as possible.
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// Reset the session if required.
		if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
			conn.Close()
			return nil, err
		}

		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// Make the connRequest channel. It's buffered so that the
		// connectionOpener doesn't block while waiting for the req to be read.
		req := make(chan connRequest, 1)
		delHandle := db.connRequests.Add(req)
		db.waitCount++
		db.mu.Unlock()

		waitStart := nowFunc()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			deleted := db.connRequests.Delete(delHandle)
			db.mu.Unlock()

			db.waitDuration.Add(int64(time.Since(waitStart)))

			// If we failed to delete it, that means either the DB was closed or
			// something else grabbed it and is about to send on it.
			if !deleted {
				// TODO(bradfitz): rather than this best effort select, we
				// should probably start a goroutine to read from req. This best
				// effort select existed before the change to check 'deleted'.
				// But if we know for sure it wasn't deleted and a sender is
				// outstanding, we should probably block on req (in a new
				// goroutine) to get the connection back.
				select {
				default:
				case ret, ok := <-req:
					// A connection was already delivered; give it back.
					if ok && ret.conn != nil {
						db.putConn(ret.conn, ret.err, false)
					}
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			db.waitDuration.Add(int64(time.Since(waitStart)))

			// A closed req channel means the DB was closed while waiting.
			if !ok {
				return nil, errDBClosed
			}
			// Only check if the connection is expired if the strategy is cachedOrNewConns.
			// If we require a new connection, just re-use the connection without looking
			// at the expiry time. If it is expired, it will be checked when it is placed
			// back into the connection pool.
			// This prioritizes giving a valid connection to a client over the exact connection
			// lifetime, which could expire exactly after this point anyway.
			if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			// No connection was delivered: report the accompanying error.
			if ret.conn == nil {
				return nil, ret.err
			}

			// Reset the session if required.
			if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
				ret.conn.Close()
				return nil, err
			}
			return ret.conn, ret.err
		}
	}

	// Below the limit (or unlimited): open a connection on this goroutine.
	// The connect call is made without holding db.mu.
	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
  1451  
// putConnHook is a hook for testing.
// When non-nil, putConn calls it with db.mu held, just before offering a
// connection that is not being discarded to putConnDBLocked.
var putConnHook func(*DB, *driverConn)
  1454  
  1455  // noteUnusedDriverStatement notes that ds is no longer used and should
  1456  // be closed whenever possible (when c is next not in use), unless c is
  1457  // already closed.
  1458  func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
  1459  	db.mu.Lock()
  1460  	defer db.mu.Unlock()
  1461  	if c.inUse {
  1462  		c.onPut = append(c.onPut, func() {
  1463  			ds.Close()
  1464  		})
  1465  	} else {
  1466  		c.Lock()
  1467  		fc := c.finalClosed
  1468  		c.Unlock()
  1469  		if !fc {
  1470  			ds.Close()
  1471  		}
  1472  	}
  1473  }
  1474  
// debugGetPut determines whether getConn & putConn calls' stack traces
// are returned for more verbose crashes.
// When true, putConn records the stack of each put in db.lastPut and
// prints it before panicking on a duplicate put.
const debugGetPut = false
  1478  
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
//
// resetSession is forwarded to dc.validateConnection; a connection that
// fails validation, or that has outlived db.maxLifetime, is treated as
// if err were driver.ErrBadConn and is closed instead of being reused.
// putConn panics if dc is not currently marked in use.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	// Validation runs before db.mu is taken.
	if !errors.Is(err, driver.ErrBadConn) {
		if !dc.validateConnection(resetSession) {
			err = driver.ErrBadConn
		}
	}
	db.mu.Lock()
	if !dc.inUse {
		db.mu.Unlock()
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}

	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false
	dc.returnedAt = nowFunc()

	// Run the callbacks queued while the connection was in use
	// (see noteUnusedDriverStatement), then clear the queue.
	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if errors.Is(err, driver.ErrBadConn) {
		// Don't reuse bad connections.
		// Since the conn is considered bad and is being discarded, treat it
		// as closed. Don't decrement the open count here, finalClose will
		// take care of that.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	// Neither a waiter nor the idle pool took the connection: close it.
	if !added {
		dc.Close()
		return
	}
}
  1532  
  1533  // Satisfy a connRequest or put the driverConn in the idle pool and return true
  1534  // or return false.
  1535  // putConnDBLocked will satisfy a connRequest if there is one, or it will
  1536  // return the *driverConn to the freeConn list if err == nil and the idle
  1537  // connection limit will not be exceeded.
  1538  // If err != nil, the value of dc is ignored.
  1539  // If err == nil, then dc must not equal nil.
  1540  // If a connRequest was fulfilled or the *driverConn was placed in the
  1541  // freeConn list, then true is returned, otherwise false is returned.
  1542  func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
  1543  	if db.closed {
  1544  		return false
  1545  	}
  1546  	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
  1547  		return false
  1548  	}
  1549  	if req, ok := db.connRequests.TakeRandom(); ok {
  1550  		if err == nil {
  1551  			dc.inUse = true
  1552  		}
  1553  		req <- connRequest{
  1554  			conn: dc,
  1555  			err:  err,
  1556  		}
  1557  		return true
  1558  	} else if err == nil && !db.closed {
  1559  		if db.maxIdleConnsLocked() > len(db.freeConn) {
  1560  			db.freeConn = append(db.freeConn, dc)
  1561  			db.startCleanerLocked()
  1562  			return true
  1563  		}
  1564  		db.maxIdleClosed++
  1565  	}
  1566  	return false
  1567  }
  1568  
// maxBadConnRetries is the maximum number of retries if the driver returns
// driver.ErrBadConn to signal a broken connection before forcing a new
// connection to be opened.
const maxBadConnRetries = 2
  1573  
  1574  func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
  1575  	for i := int64(0); i < maxBadConnRetries; i++ {
  1576  		err := fn(cachedOrNewConn)
  1577  		// retry if err is driver.ErrBadConn
  1578  		if err == nil || !errors.Is(err, driver.ErrBadConn) {
  1579  			return err
  1580  		}
  1581  	}
  1582  
  1583  	return fn(alwaysNewConn)
  1584  }
  1585  
  1586  // PrepareContext creates a prepared statement for later queries or executions.
  1587  // Multiple queries or executions may be run concurrently from the
  1588  // returned statement.
  1589  // The caller must call the statement's [*Stmt.Close] method
  1590  // when the statement is no longer needed.
  1591  //
  1592  // The provided context is used for the preparation of the statement, not for the
  1593  // execution of the statement.
  1594  func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  1595  	var stmt *Stmt
  1596  	var err error
  1597  
  1598  	err = db.retry(func(strategy connReuseStrategy) error {
  1599  		stmt, err = db.prepare(ctx, query, strategy)
  1600  		return err
  1601  	})
  1602  
  1603  	return stmt, err
  1604  }
  1605  
  1606  // Prepare creates a prepared statement for later queries or executions.
  1607  // Multiple queries or executions may be run concurrently from the
  1608  // returned statement.
  1609  // The caller must call the statement's [*Stmt.Close] method
  1610  // when the statement is no longer needed.
  1611  //
  1612  // Prepare uses [context.Background] internally; to specify the context, use
  1613  // [DB.PrepareContext].
  1614  func (db *DB) Prepare(query string) (*Stmt, error) {
  1615  	return db.PrepareContext(context.Background(), query)
  1616  }
  1617  
  1618  func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
  1619  	// TODO: check if db.driver supports an optional
  1620  	// driver.Preparer interface and call that instead, if so,
  1621  	// otherwise we make a prepared statement that's bound
  1622  	// to a connection, and to execute this prepared statement
  1623  	// we either need to use this connection (if it's free), else
  1624  	// get a new connection + re-prepare + execute on that one.
  1625  	dc, err := db.conn(ctx, strategy)
  1626  	if err != nil {
  1627  		return nil, err
  1628  	}
  1629  	return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
  1630  }
  1631  
  1632  // prepareDC prepares a query on the driverConn and calls release before
  1633  // returning. When cg == nil it implies that a connection pool is used, and
  1634  // when cg != nil only a single driver connection is used.
  1635  func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
  1636  	var ds *driverStmt
  1637  	var err error
  1638  	defer func() {
  1639  		release(err)
  1640  	}()
  1641  	withLock(dc, func() {
  1642  		ds, err = dc.prepareLocked(ctx, cg, query)
  1643  	})
  1644  	if err != nil {
  1645  		return nil, err
  1646  	}
  1647  	stmt := &Stmt{
  1648  		db:    db,
  1649  		query: query,
  1650  		cg:    cg,
  1651  		cgds:  ds,
  1652  	}
  1653  
  1654  	// When cg == nil this statement will need to keep track of various
  1655  	// connections they are prepared on and record the stmt dependency on
  1656  	// the DB.
  1657  	if cg == nil {
  1658  		stmt.css = []connStmt{{dc, ds}}
  1659  		stmt.lastNumClosed = db.numClosed.Load()
  1660  		db.addDep(stmt, stmt)
  1661  	}
  1662  	return stmt, nil
  1663  }
  1664  
  1665  // ExecContext executes a query without returning any rows.
  1666  // The args are for any placeholder parameters in the query.
  1667  func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
  1668  	var res Result
  1669  	var err error
  1670  
  1671  	err = db.retry(func(strategy connReuseStrategy) error {
  1672  		res, err = db.exec(ctx, query, args, strategy)
  1673  		return err
  1674  	})
  1675  
  1676  	return res, err
  1677  }
  1678  
  1679  // Exec executes a query without returning any rows.
  1680  // The args are for any placeholder parameters in the query.
  1681  //
  1682  // Exec uses [context.Background] internally; to specify the context, use
  1683  // [DB.ExecContext].
  1684  func (db *DB) Exec(query string, args ...any) (Result, error) {
  1685  	return db.ExecContext(context.Background(), query, args...)
  1686  }
  1687  
  1688  func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
  1689  	dc, err := db.conn(ctx, strategy)
  1690  	if err != nil {
  1691  		return nil, err
  1692  	}
  1693  	return db.execDC(ctx, dc, dc.releaseConn, query, args)
  1694  }
  1695  
// execDC executes query with args on dc and calls release with the
// resulting error before returning.
//
// If dc.ci implements driver.ExecerContext or driver.Execer, that fast
// path is tried first. When neither is implemented, or the fast path
// reports driver.ErrSkip, the query is prepared as a statement,
// executed, and the statement is closed before the connection is released.
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
	// err is a named result so the deferred release sees the final error.
	defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var resi driver.Result
		// Argument conversion and the exec call both run with dc locked.
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		// driver.ErrSkip falls through to the prepared-statement path below.
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	// Deferred calls run last-in first-out: ds.Close runs before release.
	defer ds.Close()
	return resultFromStatement(ctx, dc.ci, ds, args...)
}
  1734  
  1735  // QueryContext executes a query that returns rows, typically a SELECT.
  1736  // The args are for any placeholder parameters in the query.
  1737  func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
  1738  	var rows *Rows
  1739  	var err error
  1740  
  1741  	err = db.retry(func(strategy connReuseStrategy) error {
  1742  		rows, err = db.query(ctx, query, args, strategy)
  1743  		return err
  1744  	})
  1745  
  1746  	return rows, err
  1747  }
  1748  
  1749  // Query executes a query that returns rows, typically a SELECT.
  1750  // The args are for any placeholder parameters in the query.
  1751  //
  1752  // Query uses [context.Background] internally; to specify the context, use
  1753  // [DB.QueryContext].
  1754  func (db *DB) Query(query string, args ...any) (*Rows, error) {
  1755  	return db.QueryContext(context.Background(), query, args...)
  1756  }
  1757  
  1758  func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
  1759  	dc, err := db.conn(ctx, strategy)
  1760  	if err != nil {
  1761  		return nil, err
  1762  	}
  1763  
  1764  	return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
  1765  }
  1766  
// queryDC executes a query on the given connection.
// The connection gets released by the releaseConn function.
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
//
// On error, releaseConn is called with that error before returning. On
// success the returned *Rows stores dc and releaseConn, so the
// connection is not released here.
//
// If dc.ci implements driver.QueryerContext or driver.Queryer, that fast
// path is tried first. When neither is implemented, or the fast path
// reports driver.ErrSkip, the query is prepared as a statement and the
// statement is stored on the *Rows as closeStmt.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		// Argument conversion and the query call both run with dc locked.
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// driver.ErrSkip falls through to the prepared-statement path below.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// Note: ownership of dc passes to the *Rows, to be freed
			// with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

	ds := &driverStmt{Locker: dc, si: si}
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		// Close the statement first, then give the connection back.
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// Note: ownership of dc passes to the *Rows, to be freed
	// with releaseConn.
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}
  1834  
  1835  // QueryRowContext executes a query that is expected to return at most one row.
  1836  // QueryRowContext always returns a non-nil value. Errors are deferred until
  1837  // [Row]'s Scan method is called.
  1838  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  1839  // Otherwise, [*Row.Scan] scans the first selected row and discards
  1840  // the rest.
  1841  func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
  1842  	rows, err := db.QueryContext(ctx, query, args...)
  1843  	return &Row{rows: rows, err: err}
  1844  }
  1845  
  1846  // QueryRow executes a query that is expected to return at most one row.
  1847  // QueryRow always returns a non-nil value. Errors are deferred until
  1848  // [Row]'s Scan method is called.
  1849  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  1850  // Otherwise, [*Row.Scan] scans the first selected row and discards
  1851  // the rest.
  1852  //
  1853  // QueryRow uses [context.Background] internally; to specify the context, use
  1854  // [DB.QueryRowContext].
  1855  func (db *DB) QueryRow(query string, args ...any) *Row {
  1856  	return db.QueryRowContext(context.Background(), query, args...)
  1857  }
  1858  
  1859  // BeginTx starts a transaction.
  1860  //
  1861  // The provided context is used until the transaction is committed or rolled back.
  1862  // If the context is canceled, the sql package will roll back
  1863  // the transaction. [Tx.Commit] will return an error if the context provided to
  1864  // BeginTx is canceled.
  1865  //
  1866  // The provided [TxOptions] is optional and may be nil if defaults should be used.
  1867  // If a non-default isolation level is used that the driver doesn't support,
  1868  // an error will be returned.
  1869  func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
  1870  	var tx *Tx
  1871  	var err error
  1872  
  1873  	err = db.retry(func(strategy connReuseStrategy) error {
  1874  		tx, err = db.begin(ctx, opts, strategy)
  1875  		return err
  1876  	})
  1877  
  1878  	return tx, err
  1879  }
  1880  
  1881  // Begin starts a transaction. The default isolation level is dependent on
  1882  // the driver.
  1883  //
  1884  // Begin uses [context.Background] internally; to specify the context, use
  1885  // [DB.BeginTx].
  1886  func (db *DB) Begin() (*Tx, error) {
  1887  	return db.BeginTx(context.Background(), nil)
  1888  }
  1889  
  1890  func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
  1891  	dc, err := db.conn(ctx, strategy)
  1892  	if err != nil {
  1893  		return nil, err
  1894  	}
  1895  	return db.beginDC(ctx, dc, dc.releaseConn, opts)
  1896  }
  1897  
// beginDC starts a transaction. The provided dc must be valid and ready to use.
//
// If the driver fails to begin the transaction, release is called with the
// error and the error is returned. Otherwise release is stored in the
// returned Tx as its releaseConn. opts is passed through to ctxDriverBegin.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	keepConnOnRollback := false
	withLock(dc, func() {
		// The connection is kept on rollback only when the driver connection
		// implements both driver.SessionResetter and driver.Validator.
		_, hasSessionResetter := dc.ci.(driver.SessionResetter)
		_, hasConnectionValidator := dc.ci.(driver.Validator)
		keepConnOnRollback = hasSessionResetter && hasConnectionValidator
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	// Schedule the transaction to rollback when the context is canceled.
	// The cancel function in Tx will be called after done is set to true.
	ctx, cancel := context.WithCancel(ctx)
	tx = &Tx{
		db:                 db,
		dc:                 dc,
		releaseConn:        release,
		txi:                txi,
		cancel:             cancel,
		keepConnOnRollback: keepConnOnRollback,
		ctx:                ctx,
	}
	// awaitDone blocks until tx.ctx is done and then rolls back the
	// transaction if it has not already finished.
	go tx.awaitDone()
	return tx, nil
}
  1928  
  1929  // Driver returns the database's underlying driver.
  1930  func (db *DB) Driver() driver.Driver {
  1931  	return db.connector.Driver()
  1932  }
  1933  
// ErrConnDone is returned by any operation that is performed on a connection
// that has already been returned to the connection pool, such as a [Conn]
// after [Conn.Close].
var ErrConnDone = errors.New("sql: connection is already closed")
  1937  
  1938  // Conn returns a single connection by either opening a new connection
  1939  // or returning an existing connection from the connection pool. Conn will
  1940  // block until either a connection is returned or ctx is canceled.
  1941  // Queries run on the same Conn will be run in the same database session.
  1942  //
  1943  // Every Conn must be returned to the database pool after use by
  1944  // calling [Conn.Close].
  1945  func (db *DB) Conn(ctx context.Context) (*Conn, error) {
  1946  	var dc *driverConn
  1947  	var err error
  1948  
  1949  	err = db.retry(func(strategy connReuseStrategy) error {
  1950  		dc, err = db.conn(ctx, strategy)
  1951  		return err
  1952  	})
  1953  
  1954  	if err != nil {
  1955  		return nil, err
  1956  	}
  1957  
  1958  	conn := &Conn{
  1959  		db: db,
  1960  		dc: dc,
  1961  	}
  1962  	return conn, nil
  1963  }
  1964  
// releaseConn is the type of the release function returned by grabConn.
// It must be called when the operation using the connection completes.
type releaseConn func(error)
  1966  
// Conn represents a single database connection rather than a pool of database
// connections. Prefer running queries from [DB] unless there is a specific
// need for a continuous single database connection.
//
// A Conn must call [Conn.Close] to return the connection to the database pool
// and may do so concurrently with a running query.
//
// After a call to [Conn.Close], all operations on the
// connection fail with [ErrConnDone].
type Conn struct {
	// db is the database this connection was obtained from.
	// It is set to nil when the Conn is closed.
	db *DB

	// closemu prevents the connection from closing while there
	// is an active query. It is held for read during queries
	// and exclusively during close.
	closemu sync.RWMutex

	// dc is owned until close, at which point
	// it's returned to the connection pool.
	dc *driverConn

	// done transitions from false to true exactly once, on close.
	// Once done, all operations fail with ErrConnDone.
	done atomic.Bool

	// releaseConnOnce guards the one-time initialization of releaseConnCache.
	releaseConnOnce sync.Once
	// releaseConnCache is a cache of c.closemuRUnlockCondReleaseConn
	// to save allocations in a call to grabConn.
	releaseConnCache releaseConn
}
  1997  
  1998  // grabConn takes a context to implement stmtConnGrabber
  1999  // but the context is not used.
  2000  func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
  2001  	if c.done.Load() {
  2002  		return nil, nil, ErrConnDone
  2003  	}
  2004  	c.releaseConnOnce.Do(func() {
  2005  		c.releaseConnCache = c.closemuRUnlockCondReleaseConn
  2006  	})
  2007  	c.closemu.RLock()
  2008  	return c.dc, c.releaseConnCache, nil
  2009  }
  2010  
  2011  // PingContext verifies the connection to the database is still alive.
  2012  func (c *Conn) PingContext(ctx context.Context) error {
  2013  	dc, release, err := c.grabConn(ctx)
  2014  	if err != nil {
  2015  		return err
  2016  	}
  2017  	return c.db.pingDC(ctx, dc, release)
  2018  }
  2019  
  2020  // ExecContext executes a query without returning any rows.
  2021  // The args are for any placeholder parameters in the query.
  2022  func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
  2023  	dc, release, err := c.grabConn(ctx)
  2024  	if err != nil {
  2025  		return nil, err
  2026  	}
  2027  	return c.db.execDC(ctx, dc, release, query, args)
  2028  }
  2029  
  2030  // QueryContext executes a query that returns rows, typically a SELECT.
  2031  // The args are for any placeholder parameters in the query.
  2032  func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
  2033  	dc, release, err := c.grabConn(ctx)
  2034  	if err != nil {
  2035  		return nil, err
  2036  	}
  2037  	return c.db.queryDC(ctx, nil, dc, release, query, args)
  2038  }
  2039  
  2040  // QueryRowContext executes a query that is expected to return at most one row.
  2041  // QueryRowContext always returns a non-nil value. Errors are deferred until
  2042  // the [*Row.Scan] method is called.
  2043  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  2044  // Otherwise, the [*Row.Scan] scans the first selected row and discards
  2045  // the rest.
  2046  func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
  2047  	rows, err := c.QueryContext(ctx, query, args...)
  2048  	return &Row{rows: rows, err: err}
  2049  }
  2050  
  2051  // PrepareContext creates a prepared statement for later queries or executions.
  2052  // Multiple queries or executions may be run concurrently from the
  2053  // returned statement.
  2054  // The caller must call the statement's [*Stmt.Close] method
  2055  // when the statement is no longer needed.
  2056  //
  2057  // The provided context is used for the preparation of the statement, not for the
  2058  // execution of the statement.
  2059  func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  2060  	dc, release, err := c.grabConn(ctx)
  2061  	if err != nil {
  2062  		return nil, err
  2063  	}
  2064  	return c.db.prepareDC(ctx, dc, release, c, query)
  2065  }
  2066  
  2067  // Raw executes f exposing the underlying driver connection for the
  2068  // duration of f. The driverConn must not be used outside of f.
  2069  //
  2070  // Once f returns and err is not [driver.ErrBadConn], the [Conn] will continue to be usable
  2071  // until [Conn.Close] is called.
  2072  func (c *Conn) Raw(f func(driverConn any) error) (err error) {
  2073  	var dc *driverConn
  2074  	var release releaseConn
  2075  
  2076  	// grabConn takes a context to implement stmtConnGrabber, but the context is not used.
  2077  	dc, release, err = c.grabConn(nil)
  2078  	if err != nil {
  2079  		return
  2080  	}
  2081  	fPanic := true
  2082  	dc.Mutex.Lock()
  2083  	defer func() {
  2084  		dc.Mutex.Unlock()
  2085  
  2086  		// If f panics fPanic will remain true.
  2087  		// Ensure an error is passed to release so the connection
  2088  		// may be discarded.
  2089  		if fPanic {
  2090  			err = driver.ErrBadConn
  2091  		}
  2092  		release(err)
  2093  	}()
  2094  	err = f(dc.ci)
  2095  	fPanic = false
  2096  
  2097  	return
  2098  }
  2099  
  2100  // BeginTx starts a transaction.
  2101  //
  2102  // The provided context is used until the transaction is committed or rolled back.
  2103  // If the context is canceled, the sql package will roll back
  2104  // the transaction. [Tx.Commit] will return an error if the context provided to
  2105  // BeginTx is canceled.
  2106  //
  2107  // The provided [TxOptions] is optional and may be nil if defaults should be used.
  2108  // If a non-default isolation level is used that the driver doesn't support,
  2109  // an error will be returned.
  2110  func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
  2111  	dc, release, err := c.grabConn(ctx)
  2112  	if err != nil {
  2113  		return nil, err
  2114  	}
  2115  	return c.db.beginDC(ctx, dc, release, opts)
  2116  }
  2117  
  2118  // closemuRUnlockCondReleaseConn read unlocks closemu
  2119  // as the sql operation is done with the dc.
  2120  func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
  2121  	c.closemu.RUnlock()
  2122  	if errors.Is(err, driver.ErrBadConn) {
  2123  		c.close(err)
  2124  	}
  2125  }
  2126  
  2127  func (c *Conn) txCtx() context.Context {
  2128  	return nil
  2129  }
  2130  
// close marks the Conn as done and releases its driverConn, passing err to
// dc.releaseConn. It returns ErrConnDone if the Conn was already closed and
// otherwise returns err unchanged.
func (c *Conn) close(err error) error {
	// Only the caller that flips done from false to true proceeds.
	if !c.done.CompareAndSwap(false, true) {
		return ErrConnDone
	}

	// Lock around releasing the driver connection
	// to ensure all queries have been stopped before doing so.
	c.closemu.Lock()
	defer c.closemu.Unlock()

	c.dc.releaseConn(err)
	// Drop the references; the Conn no longer owns the connection.
	c.dc = nil
	c.db = nil
	return err
}
  2146  
  2147  // Close returns the connection to the connection pool.
  2148  // All operations after a Close will return with [ErrConnDone].
  2149  // Close is safe to call concurrently with other operations and will
  2150  // block until all other operations finish. It may be useful to first
  2151  // cancel any used context and then call close directly after.
  2152  func (c *Conn) Close() error {
  2153  	return c.close(nil)
  2154  }
  2155  
// Tx is an in-progress database transaction.
//
// A transaction must end with a call to [Tx.Commit] or [Tx.Rollback].
//
// After a call to [Tx.Commit] or [Tx.Rollback], all operations on the
// transaction fail with [ErrTxDone].
//
// The statements prepared for a transaction by calling
// the transaction's [Tx.Prepare] or [Tx.Stmt] methods are closed
// by the call to [Tx.Commit] or [Tx.Rollback].
type Tx struct {
	// db is the database the transaction was started on.
	db *DB

	// closemu prevents the transaction from closing while there
	// is an active query. It is held for read during queries
	// and exclusively during close.
	closemu sync.RWMutex

	// dc is owned exclusively until Commit or Rollback, at which point
	// it's returned with putConn.
	dc  *driverConn
	txi driver.Tx

	// releaseConn is called once the Tx is closed to release
	// any held driverConn back to the pool.
	releaseConn func(error)

	// done transitions from false to true exactly once, on Commit
	// or Rollback. Once done, all operations fail with
	// ErrTxDone.
	done atomic.Bool

	// keepConnOnRollback is true if the driver knows
	// how to reset the connection's session and if need be discard
	// the connection.
	keepConnOnRollback bool

	// All Stmts prepared for this transaction. These will be closed after the
	// transaction has been committed or rolled back.
	// The embedded mutex guards v.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel is called after done transitions from false to true.
	cancel func()

	// ctx lives for the life of the transaction.
	ctx context.Context
}
  2206  
  2207  // awaitDone blocks until the context in Tx is canceled and rolls back
  2208  // the transaction if it's not already done.
  2209  func (tx *Tx) awaitDone() {
  2210  	// Wait for either the transaction to be committed or rolled
  2211  	// back, or for the associated context to be closed.
  2212  	<-tx.ctx.Done()
  2213  
  2214  	// Discard and close the connection used to ensure the
  2215  	// transaction is closed and the resources are released.  This
  2216  	// rollback does nothing if the transaction has already been
  2217  	// committed or rolled back.
  2218  	// Do not discard the connection if the connection knows
  2219  	// how to reset the session.
  2220  	discardConnection := !tx.keepConnOnRollback
  2221  	tx.rollback(discardConnection)
  2222  }
  2223  
  2224  func (tx *Tx) isDone() bool {
  2225  	return tx.done.Load()
  2226  }
  2227  
// ErrTxDone is returned by any operation that is performed on a transaction
// that has already been committed or rolled back, including a second call to
// [Tx.Commit] or [Tx.Rollback].
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
  2231  
  2232  // close returns the connection to the pool and
  2233  // must only be called by Tx.rollback or Tx.Commit while
  2234  // tx is already canceled and won't be executed concurrently.
  2235  func (tx *Tx) close(err error) {
  2236  	tx.releaseConn(err)
  2237  	tx.dc = nil
  2238  	tx.txi = nil
  2239  }
  2240  
// hookTxGrabConn specifies an optional hook to be called on
// a successful call to (*Tx).grabConn, while closemu is still
// held for read. For tests.
var hookTxGrabConn func()
  2244  
// grabConn returns the transaction's driverConn together with the release
// function that must be called when the operation completes.
//
// It returns ctx.Err() if ctx is already done and ErrTxDone if the
// transaction has been committed or rolled back. On success closemu is held
// for read until the returned release function is called.
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
	// Non-blocking check: fail right away if ctx is already done.
	select {
	default:
	case <-ctx.Done():
		return nil, nil, ctx.Err()
	}

	// closemu.RLock must come before the check for isDone to prevent the Tx from
	// closing while a query is executing.
	tx.closemu.RLock()
	if tx.isDone() {
		tx.closemu.RUnlock()
		return nil, nil, ErrTxDone
	}
	if hookTxGrabConn != nil { // test hook
		hookTxGrabConn()
	}
	return tx.dc, tx.closemuRUnlockRelease, nil
}
  2264  
  2265  func (tx *Tx) txCtx() context.Context {
  2266  	return tx.ctx
  2267  }
  2268  
// closemuRUnlockRelease is the release function returned by Tx.grabConn,
// used as a func(error) method value. Unlocking in the releaseConn keeps
// the driver conn from being returned to the connection pool until
// the Rows has been closed.
// The error argument is ignored; the function only read-unlocks closemu.
func (tx *Tx) closemuRUnlockRelease(error) {
	tx.closemu.RUnlock()
}
  2276  
  2277  // Closes all Stmts prepared for this transaction.
  2278  func (tx *Tx) closePrepared() {
  2279  	tx.stmts.Lock()
  2280  	defer tx.stmts.Unlock()
  2281  	for _, stmt := range tx.stmts.v {
  2282  		stmt.Close()
  2283  	}
  2284  }
  2285  
// Commit commits the transaction.
//
// It returns [ErrTxDone] if the transaction has already been committed or
// rolled back. If the transaction context is done while the transaction is
// still open, it returns the context's error without committing.
func (tx *Tx) Commit() error {
	// Check context first to avoid transaction leak.
	// If put it behind tx.done CompareAndSwap statement, we can't ensure
	// the consistency between tx.done and the real COMMIT operation.
	select {
	default:
	case <-tx.ctx.Done():
		if tx.done.Load() {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	// Cancel the Tx to release any active R-closemu locks.
	// This is safe to do because tx.done has already transitioned
	// from false to true. Acquire the W-closemu lock prior to commit
	// to ensure no other connection has an active query.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	// The prepared statements are not closed when the driver reports
	// driver.ErrBadConn.
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
  2321  
// rollbackHook, if non-nil, is called by (*Tx).rollback after the
// transaction has been marked done and before it is canceled.
// NOTE(review): presumably set only by tests — confirm.
var rollbackHook func()
  2323  
// rollback aborts the transaction and optionally forces the pool to discard
// the connection.
//
// It returns ErrTxDone if the transaction has already been committed or
// rolled back. When discardConn is true the driver's rollback error is
// replaced by driver.ErrBadConn, which is both passed to tx.close and
// returned.
func (tx *Tx) rollback(discardConn bool) error {
	if !tx.done.CompareAndSwap(false, true) {
		return ErrTxDone
	}

	if rollbackHook != nil {
		rollbackHook()
	}

	// Cancel the Tx to release any active R-closemu locks.
	// This is safe to do because tx.done has already transitioned
	// from false to true. Acquire the W-closemu lock prior to rollback
	// to ensure no other connection has an active query.
	tx.cancel()
	tx.closemu.Lock()
	tx.closemu.Unlock()

	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	// The prepared statements are not closed when the driver reports
	// driver.ErrBadConn.
	if !errors.Is(err, driver.ErrBadConn) {
		tx.closePrepared()
	}
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}
  2356  
  2357  // Rollback aborts the transaction.
  2358  func (tx *Tx) Rollback() error {
  2359  	return tx.rollback(false)
  2360  }
  2361  
  2362  // PrepareContext creates a prepared statement for use within a transaction.
  2363  //
  2364  // The returned statement operates within the transaction and will be closed
  2365  // when the transaction has been committed or rolled back.
  2366  //
  2367  // To use an existing prepared statement on this transaction, see [Tx.Stmt].
  2368  //
  2369  // The provided context will be used for the preparation of the context, not
  2370  // for the execution of the returned statement. The returned statement
  2371  // will run in the transaction context.
  2372  func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  2373  	dc, release, err := tx.grabConn(ctx)
  2374  	if err != nil {
  2375  		return nil, err
  2376  	}
  2377  
  2378  	stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
  2379  	if err != nil {
  2380  		return nil, err
  2381  	}
  2382  	tx.stmts.Lock()
  2383  	tx.stmts.v = append(tx.stmts.v, stmt)
  2384  	tx.stmts.Unlock()
  2385  	return stmt, nil
  2386  }
  2387  
  2388  // Prepare creates a prepared statement for use within a transaction.
  2389  //
  2390  // The returned statement operates within the transaction and will be closed
  2391  // when the transaction has been committed or rolled back.
  2392  //
  2393  // To use an existing prepared statement on this transaction, see [Tx.Stmt].
  2394  //
  2395  // Prepare uses [context.Background] internally; to specify the context, use
  2396  // [Tx.PrepareContext].
  2397  func (tx *Tx) Prepare(query string) (*Stmt, error) {
  2398  	return tx.PrepareContext(context.Background(), query)
  2399  }
  2400  
// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
//
//	updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
//	...
//	tx, err := db.Begin()
//	...
//	res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The provided context is used for the preparation of the statement, not for the
// execution of the statement.
//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
//
// StmtContext never returns nil. On failure it returns a Stmt whose
// stickyErr is set to the error.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	// The connection is only needed while preparing; release it on return.
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// If the statement has been closed or already belongs to a
		// transaction, we can't reuse it in this connection.
		// Since tx.StmtContext should never need to be called with a
		// Stmt already belonging to tx, we ignore this edge case and
		// re-prepare the statement in this case. No need to add
		// code-complexity for this.
		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()
		// See if the statement has already been prepared on this connection,
		// and reuse it if possible.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		// Not yet prepared on this connection: prepare it now, with dc locked.
		if si == nil {
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		// Only in this branch does the new Stmt depend on the original one.
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si:     si,
		},
		parentStmt: parentStmt,
		query:      stmt.query,
	}
	if parentStmt != nil {
		tx.db.addDep(parentStmt, txs)
	}
	// Record the statement so it is closed together with the transaction.
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}
  2488  
  2489  // Stmt returns a transaction-specific prepared statement from
  2490  // an existing statement.
  2491  //
  2492  // Example:
  2493  //
  2494  //	updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
  2495  //	...
  2496  //	tx, err := db.Begin()
  2497  //	...
  2498  //	res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
  2499  //
  2500  // The returned statement operates within the transaction and will be closed
  2501  // when the transaction has been committed or rolled back.
  2502  //
  2503  // Stmt uses [context.Background] internally; to specify the context, use
  2504  // [Tx.StmtContext].
  2505  func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
  2506  	return tx.StmtContext(context.Background(), stmt)
  2507  }
  2508  
  2509  // ExecContext executes a query that doesn't return rows.
  2510  // For example: an INSERT and UPDATE.
  2511  func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
  2512  	dc, release, err := tx.grabConn(ctx)
  2513  	if err != nil {
  2514  		return nil, err
  2515  	}
  2516  	return tx.db.execDC(ctx, dc, release, query, args)
  2517  }
  2518  
  2519  // Exec executes a query that doesn't return rows.
  2520  // For example: an INSERT and UPDATE.
  2521  //
  2522  // Exec uses [context.Background] internally; to specify the context, use
  2523  // [Tx.ExecContext].
  2524  func (tx *Tx) Exec(query string, args ...any) (Result, error) {
  2525  	return tx.ExecContext(context.Background(), query, args...)
  2526  }
  2527  
  2528  // QueryContext executes a query that returns rows, typically a SELECT.
  2529  func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
  2530  	dc, release, err := tx.grabConn(ctx)
  2531  	if err != nil {
  2532  		return nil, err
  2533  	}
  2534  
  2535  	return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
  2536  }
  2537  
  2538  // Query executes a query that returns rows, typically a SELECT.
  2539  //
  2540  // Query uses [context.Background] internally; to specify the context, use
  2541  // [Tx.QueryContext].
  2542  func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
  2543  	return tx.QueryContext(context.Background(), query, args...)
  2544  }
  2545  
  2546  // QueryRowContext executes a query that is expected to return at most one row.
  2547  // QueryRowContext always returns a non-nil value. Errors are deferred until
  2548  // [Row]'s Scan method is called.
  2549  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  2550  // Otherwise, the [*Row.Scan] scans the first selected row and discards
  2551  // the rest.
  2552  func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
  2553  	rows, err := tx.QueryContext(ctx, query, args...)
  2554  	return &Row{rows: rows, err: err}
  2555  }
  2556  
  2557  // QueryRow executes a query that is expected to return at most one row.
  2558  // QueryRow always returns a non-nil value. Errors are deferred until
  2559  // [Row]'s Scan method is called.
  2560  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  2561  // Otherwise, the [*Row.Scan] scans the first selected row and discards
  2562  // the rest.
  2563  //
  2564  // QueryRow uses [context.Background] internally; to specify the context, use
  2565  // [Tx.QueryRowContext].
  2566  func (tx *Tx) QueryRow(query string, args ...any) *Row {
  2567  	return tx.QueryRowContext(context.Background(), query, args...)
  2568  }
  2569  
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
	dc *driverConn // the connection the statement was prepared on
	ds *driverStmt // the driver statement that is valid on dc
}
  2575  
// stmtConnGrabber represents a Tx or Conn that will return the underlying
// driverConn and release function.
type stmtConnGrabber interface {
	// grabConn returns the driverConn and the associated release function
	// that must be called when the operation completes.
	// On error no release function is returned.
	grabConn(context.Context) (*driverConn, releaseConn, error)

	// txCtx returns the transaction context if available.
	// The returned context should be selected on along with
	// any query context when awaiting a cancel.
	// It is nil when there is no transaction context, as for a Conn.
	txCtx() context.Context
}
  2588  
  2589  var (
  2590  	_ stmtConnGrabber = &Tx{}
  2591  	_ stmtConnGrabber = &Conn{}
  2592  )
  2593  
// Stmt is a prepared statement.
// A Stmt is safe for concurrent use by multiple goroutines.
//
// If a Stmt is prepared on a [Tx] or [Conn], it will be bound to a single
// underlying connection forever. If the [Tx] or [Conn] closes, the Stmt will
// become unusable and all operations will return an error.
// If a Stmt is prepared on a [DB], it will remain usable for the lifetime of the
// [DB]. When the Stmt needs to execute on a new underlying connection, it will
// prepare itself on the new connection automatically.
type Stmt struct {
	// Immutable:
	db        *DB    // where we came from
	query     string // that created the Stmt
	stickyErr error  // if non-nil, this error is returned for all operations

	closemu sync.RWMutex // held exclusively during close, for read otherwise.

	// If Stmt is prepared on a Tx or Conn then cg is present and will
	// only ever grab a connection from cg.
	// If cg is nil then the Stmt must grab an arbitrary connection
	// from db and determine if it must prepare the stmt again by
	// inspecting css.
	// cgds is the driver statement that is used whenever cg is non-nil.
	cg   stmtConnGrabber
	cgds *driverStmt

	// parentStmt is set when a transaction-specific statement
	// is requested from an identical statement prepared on the same
	// conn. parentStmt is used to track the dependency of this statement
	// on its originating ("parent") statement so that parentStmt may
	// be closed by the user without them having to know whether or not
	// any transactions are still using it.
	parentStmt *Stmt

	mu     sync.Mutex // protects the rest of the fields
	closed bool

	// css is a list of underlying driver statement interfaces
	// that are valid on particular connections. This is only
	// used if cg == nil and one is found that has idle
	// connections. If cg != nil, cgds is always used.
	css []connStmt

	// lastNumClosed is copied from db.numClosed when Stmt is created
	// without tx and closed connections in css are removed.
	lastNumClosed uint64
}
  2640  
  2641  // ExecContext executes a prepared statement with the given arguments and
  2642  // returns a [Result] summarizing the effect of the statement.
  2643  func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
  2644  	s.closemu.RLock()
  2645  	defer s.closemu.RUnlock()
  2646  
  2647  	var res Result
  2648  	err := s.db.retry(func(strategy connReuseStrategy) error {
  2649  		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
  2650  		if err != nil {
  2651  			return err
  2652  		}
  2653  
  2654  		res, err = resultFromStatement(ctx, dc.ci, ds, args...)
  2655  		releaseConn(err)
  2656  		return err
  2657  	})
  2658  
  2659  	return res, err
  2660  }
  2661  
  2662  // Exec executes a prepared statement with the given arguments and
  2663  // returns a [Result] summarizing the effect of the statement.
  2664  //
  2665  // Exec uses [context.Background] internally; to specify the context, use
  2666  // [Stmt.ExecContext].
  2667  func (s *Stmt) Exec(args ...any) (Result, error) {
  2668  	return s.ExecContext(context.Background(), args...)
  2669  }
  2670  
  2671  func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
  2672  	ds.Lock()
  2673  	defer ds.Unlock()
  2674  
  2675  	dargs, err := driverArgsConnLocked(ci, ds, args)
  2676  	if err != nil {
  2677  		return nil, err
  2678  	}
  2679  
  2680  	resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
  2681  	if err != nil {
  2682  		return nil, err
  2683  	}
  2684  	return driverResult{ds.Locker, resi}, nil
  2685  }
  2686  
  2687  // removeClosedStmtLocked removes closed conns in s.css.
  2688  //
  2689  // To avoid lock contention on DB.mu, we do it only when
  2690  // s.db.numClosed - s.lastNum is large enough.
  2691  func (s *Stmt) removeClosedStmtLocked() {
  2692  	t := len(s.css)/2 + 1
  2693  	if t > 10 {
  2694  		t = 10
  2695  	}
  2696  	dbClosed := s.db.numClosed.Load()
  2697  	if dbClosed-s.lastNumClosed < uint64(t) {
  2698  		return
  2699  	}
  2700  
  2701  	s.db.mu.Lock()
  2702  	for i := 0; i < len(s.css); i++ {
  2703  		if s.css[i].dc.dbmuClosed {
  2704  			s.css[i] = s.css[len(s.css)-1]
  2705  			// Zero out the last element (for GC) before shrinking the slice.
  2706  			s.css[len(s.css)-1] = connStmt{}
  2707  			s.css = s.css[:len(s.css)-1]
  2708  			i--
  2709  		}
  2710  	}
  2711  	s.db.mu.Unlock()
  2712  	s.lastNumClosed = dbClosed
  2713  }
  2714  
  2715  // connStmt returns a free driver connection on which to execute the
  2716  // statement, a function to call to release the connection, and a
  2717  // statement bound to that connection.
  2718  func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
  2719  	if err = s.stickyErr; err != nil {
  2720  		return
  2721  	}
  2722  	s.mu.Lock()
  2723  	if s.closed {
  2724  		s.mu.Unlock()
  2725  		err = errors.New("sql: statement is closed")
  2726  		return
  2727  	}
  2728  
  2729  	// In a transaction or connection, we always use the connection that the
  2730  	// stmt was created on.
  2731  	if s.cg != nil {
  2732  		s.mu.Unlock()
  2733  		dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
  2734  		if err != nil {
  2735  			return
  2736  		}
  2737  		return dc, releaseConn, s.cgds, nil
  2738  	}
  2739  
  2740  	s.removeClosedStmtLocked()
  2741  	s.mu.Unlock()
  2742  
  2743  	dc, err = s.db.conn(ctx, strategy)
  2744  	if err != nil {
  2745  		return nil, nil, nil, err
  2746  	}
  2747  
  2748  	s.mu.Lock()
  2749  	for _, v := range s.css {
  2750  		if v.dc == dc {
  2751  			s.mu.Unlock()
  2752  			return dc, dc.releaseConn, v.ds, nil
  2753  		}
  2754  	}
  2755  	s.mu.Unlock()
  2756  
  2757  	// No luck; we need to prepare the statement on this connection
  2758  	withLock(dc, func() {
  2759  		ds, err = s.prepareOnConnLocked(ctx, dc)
  2760  	})
  2761  	if err != nil {
  2762  		dc.releaseConn(err)
  2763  		return nil, nil, nil, err
  2764  	}
  2765  
  2766  	return dc, dc.releaseConn, ds, nil
  2767  }
  2768  
  2769  // prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
  2770  // open connStmt on the statement. It assumes the caller is holding the lock on dc.
  2771  func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
  2772  	si, err := dc.prepareLocked(ctx, s.cg, s.query)
  2773  	if err != nil {
  2774  		return nil, err
  2775  	}
  2776  	cs := connStmt{dc, si}
  2777  	s.mu.Lock()
  2778  	s.css = append(s.css, cs)
  2779  	s.mu.Unlock()
  2780  	return cs.ds, nil
  2781  }
  2782  
  2783  // QueryContext executes a prepared query statement with the given arguments
  2784  // and returns the query results as a [*Rows].
  2785  func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
  2786  	s.closemu.RLock()
  2787  	defer s.closemu.RUnlock()
  2788  
  2789  	var rowsi driver.Rows
  2790  	var rows *Rows
  2791  
  2792  	err := s.db.retry(func(strategy connReuseStrategy) error {
  2793  		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
  2794  		if err != nil {
  2795  			return err
  2796  		}
  2797  
  2798  		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
  2799  		if err == nil {
  2800  			// Note: ownership of ci passes to the *Rows, to be freed
  2801  			// with releaseConn.
  2802  			rows = &Rows{
  2803  				dc:    dc,
  2804  				rowsi: rowsi,
  2805  				// releaseConn set below
  2806  			}
  2807  			// addDep must be added before initContextClose or it could attempt
  2808  			// to removeDep before it has been added.
  2809  			s.db.addDep(s, rows)
  2810  
  2811  			// releaseConn must be set before initContextClose or it could
  2812  			// release the connection before it is set.
  2813  			rows.releaseConn = func(err error) {
  2814  				releaseConn(err)
  2815  				s.db.removeDep(s, rows)
  2816  			}
  2817  			var txctx context.Context
  2818  			if s.cg != nil {
  2819  				txctx = s.cg.txCtx()
  2820  			}
  2821  			rows.initContextClose(ctx, txctx)
  2822  			return nil
  2823  		}
  2824  
  2825  		releaseConn(err)
  2826  		return err
  2827  	})
  2828  
  2829  	return rows, err
  2830  }
  2831  
  2832  // Query executes a prepared query statement with the given arguments
  2833  // and returns the query results as a *Rows.
  2834  //
  2835  // Query uses [context.Background] internally; to specify the context, use
  2836  // [Stmt.QueryContext].
  2837  func (s *Stmt) Query(args ...any) (*Rows, error) {
  2838  	return s.QueryContext(context.Background(), args...)
  2839  }
  2840  
  2841  func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
  2842  	ds.Lock()
  2843  	defer ds.Unlock()
  2844  	dargs, err := driverArgsConnLocked(ci, ds, args)
  2845  	if err != nil {
  2846  		return nil, err
  2847  	}
  2848  	return ctxDriverStmtQuery(ctx, ds.si, dargs)
  2849  }
  2850  
  2851  // QueryRowContext executes a prepared query statement with the given arguments.
  2852  // If an error occurs during the execution of the statement, that error will
  2853  // be returned by a call to Scan on the returned [*Row], which is always non-nil.
  2854  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  2855  // Otherwise, the [*Row.Scan] scans the first selected row and discards
  2856  // the rest.
  2857  func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
  2858  	rows, err := s.QueryContext(ctx, args...)
  2859  	if err != nil {
  2860  		return &Row{err: err}
  2861  	}
  2862  	return &Row{rows: rows}
  2863  }
  2864  
  2865  // QueryRow executes a prepared query statement with the given arguments.
  2866  // If an error occurs during the execution of the statement, that error will
  2867  // be returned by a call to Scan on the returned [*Row], which is always non-nil.
  2868  // If the query selects no rows, the [*Row.Scan] will return [ErrNoRows].
  2869  // Otherwise, the [*Row.Scan] scans the first selected row and discards
  2870  // the rest.
  2871  //
  2872  // Example usage:
  2873  //
  2874  //	var name string
  2875  //	err := nameByUseridStmt.QueryRow(id).Scan(&name)
  2876  //
  2877  // QueryRow uses [context.Background] internally; to specify the context, use
  2878  // [Stmt.QueryRowContext].
  2879  func (s *Stmt) QueryRow(args ...any) *Row {
  2880  	return s.QueryRowContext(context.Background(), args...)
  2881  }
  2882  
  2883  // Close closes the statement.
  2884  func (s *Stmt) Close() error {
  2885  	s.closemu.Lock()
  2886  	defer s.closemu.Unlock()
  2887  
  2888  	if s.stickyErr != nil {
  2889  		return s.stickyErr
  2890  	}
  2891  	s.mu.Lock()
  2892  	if s.closed {
  2893  		s.mu.Unlock()
  2894  		return nil
  2895  	}
  2896  	s.closed = true
  2897  	txds := s.cgds
  2898  	s.cgds = nil
  2899  
  2900  	s.mu.Unlock()
  2901  
  2902  	if s.cg == nil {
  2903  		return s.db.removeDep(s, s)
  2904  	}
  2905  
  2906  	if s.parentStmt != nil {
  2907  		// If parentStmt is set, we must not close s.txds since it's stored
  2908  		// in the css array of the parentStmt.
  2909  		return s.db.removeDep(s.parentStmt, s)
  2910  	}
  2911  	return txds.Close()
  2912  }
  2913  
  2914  func (s *Stmt) finalClose() error {
  2915  	s.mu.Lock()
  2916  	defer s.mu.Unlock()
  2917  	if s.css != nil {
  2918  		for _, v := range s.css {
  2919  			s.db.noteUnusedDriverStatement(v.dc, v.ds)
  2920  			v.dc.removeOpenStmt(v.ds)
  2921  		}
  2922  		s.css = nil
  2923  	}
  2924  	return nil
  2925  }
  2926  
// Rows is the result of a query. Its cursor starts before the first row
// of the result set. Use [Rows.Next] to advance from row to row.
type Rows struct {
	dc          *driverConn // owned; must call releaseConn when closed to release
	releaseConn func(error) // called once by close
	rowsi       driver.Rows // driver rows being iterated
	cancel      func()      // called when Rows is closed, may be nil.
	closeStmt   *driverStmt // if non-nil, statement to Close on close

	contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt

	// closemu prevents Rows from closing while there
	// is an active streaming result. It is held for read during non-close operations
	// and exclusively during close.
	//
	// closemu guards lasterr and closed.
	closemu sync.RWMutex
	lasterr error // last error from the driver or from close; may be io.EOF while more result sets remain
	closed  bool

	// closemuScanHold is whether the previous call to Scan kept closemu RLock'ed
	// without unlocking it. It does that when the user passes a *RawBytes scan
	// target. In that case, we need to prevent awaitDone from closing the Rows
	// while the user's still using the memory. See go.dev/issue/60304.
	//
	// It is only used by Scan, Next, and NextResultSet which are expected
	// not to be called concurrently.
	closemuScanHold bool

	// hitEOF is whether Next hit the end of the rows without
	// encountering an error. It's set in Next before
	// returning. It's only used by Next and Err which are
	// expected not to be called concurrently.
	hitEOF bool

	// lastcols is only used in Scan, Next, and NextResultSet which are expected
	// not to be called concurrently.
	lastcols []driver.Value

	// raw is a buffer for RawBytes that persists between Scan calls.
	// This is used when the driver returns a mismatched type that requires
	// a cloning allocation. For example, if the driver returns a *string and
	// the user is scanning into a *RawBytes, we need to copy the string.
	// The raw buffer here lets us reuse the memory for that copy across Scan calls.
	raw []byte
}
  2973  
  2974  // lasterrOrErrLocked returns either lasterr or the provided err.
  2975  // rs.closemu must be read-locked.
  2976  func (rs *Rows) lasterrOrErrLocked(err error) error {
  2977  	if rs.lasterr != nil && rs.lasterr != io.EOF {
  2978  		return rs.lasterr
  2979  	}
  2980  	return err
  2981  }
  2982  
  2983  // bypassRowsAwaitDone is only used for testing.
  2984  // If true, it will not close the Rows automatically from the context.
  2985  var bypassRowsAwaitDone = false
  2986  
  2987  func (rs *Rows) initContextClose(ctx, txctx context.Context) {
  2988  	if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
  2989  		return
  2990  	}
  2991  	if bypassRowsAwaitDone {
  2992  		return
  2993  	}
  2994  	closectx, cancel := context.WithCancel(ctx)
  2995  	rs.cancel = cancel
  2996  	go rs.awaitDone(ctx, txctx, closectx)
  2997  }
  2998  
  2999  // awaitDone blocks until ctx, txctx, or closectx is canceled.
  3000  // The ctx is provided from the query context.
  3001  // If the query was issued in a transaction, the transaction's context
  3002  // is also provided in txctx, to ensure Rows is closed if the Tx is closed.
  3003  // The closectx is closed by an explicit call to rs.Close.
  3004  func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
  3005  	var txctxDone <-chan struct{}
  3006  	if txctx != nil {
  3007  		txctxDone = txctx.Done()
  3008  	}
  3009  	select {
  3010  	case <-ctx.Done():
  3011  		err := ctx.Err()
  3012  		rs.contextDone.Store(&err)
  3013  	case <-txctxDone:
  3014  		err := txctx.Err()
  3015  		rs.contextDone.Store(&err)
  3016  	case <-closectx.Done():
  3017  		// rs.cancel was called via Close(); don't store this into contextDone
  3018  		// to ensure Err() is unaffected.
  3019  	}
  3020  	rs.close(ctx.Err())
  3021  }
  3022  
  3023  // Next prepares the next result row for reading with the [Rows.Scan] method. It
  3024  // returns true on success, or false if there is no next result row or an error
  3025  // happened while preparing it. [Rows.Err] should be consulted to distinguish between
  3026  // the two cases.
  3027  //
  3028  // Every call to [Rows.Scan], even the first one, must be preceded by a call to [Rows.Next].
  3029  func (rs *Rows) Next() bool {
  3030  	// If the user's calling Next, they're done with their previous row's Scan
  3031  	// results (any RawBytes memory), so we can release the read lock that would
  3032  	// be preventing awaitDone from calling close.
  3033  	rs.closemuRUnlockIfHeldByScan()
  3034  
  3035  	if rs.contextDone.Load() != nil {
  3036  		return false
  3037  	}
  3038  
  3039  	var doClose, ok bool
  3040  	withLock(rs.closemu.RLocker(), func() {
  3041  		doClose, ok = rs.nextLocked()
  3042  	})
  3043  	if doClose {
  3044  		rs.Close()
  3045  	}
  3046  	if doClose && !ok {
  3047  		rs.hitEOF = true
  3048  	}
  3049  	return ok
  3050  }
  3051  
  3052  func (rs *Rows) nextLocked() (doClose, ok bool) {
  3053  	if rs.closed {
  3054  		return false, false
  3055  	}
  3056  
  3057  	// Lock the driver connection before calling the driver interface
  3058  	// rowsi to prevent a Tx from rolling back the connection at the same time.
  3059  	rs.dc.Lock()
  3060  	defer rs.dc.Unlock()
  3061  
  3062  	if rs.lastcols == nil {
  3063  		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
  3064  	}
  3065  
  3066  	rs.lasterr = rs.rowsi.Next(rs.lastcols)
  3067  	if rs.lasterr != nil {
  3068  		// Close the connection if there is a driver error.
  3069  		if rs.lasterr != io.EOF {
  3070  			return true, false
  3071  		}
  3072  		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
  3073  		if !ok {
  3074  			return true, false
  3075  		}
  3076  		// The driver is at the end of the current result set.
  3077  		// Test to see if there is another result set after the current one.
  3078  		// Only close Rows if there is no further result sets to read.
  3079  		if !nextResultSet.HasNextResultSet() {
  3080  			doClose = true
  3081  		}
  3082  		return doClose, false
  3083  	}
  3084  	return false, true
  3085  }
  3086  
  3087  // NextResultSet prepares the next result set for reading. It reports whether
  3088  // there is further result sets, or false if there is no further result set
  3089  // or if there is an error advancing to it. The [Rows.Err] method should be consulted
  3090  // to distinguish between the two cases.
  3091  //
  3092  // After calling NextResultSet, the [Rows.Next] method should always be called before
  3093  // scanning. If there are further result sets they may not have rows in the result
  3094  // set.
  3095  func (rs *Rows) NextResultSet() bool {
  3096  	// If the user's calling NextResultSet, they're done with their previous
  3097  	// row's Scan results (any RawBytes memory), so we can release the read lock
  3098  	// that would be preventing awaitDone from calling close.
  3099  	rs.closemuRUnlockIfHeldByScan()
  3100  
  3101  	var doClose bool
  3102  	defer func() {
  3103  		if doClose {
  3104  			rs.Close()
  3105  		}
  3106  	}()
  3107  	rs.closemu.RLock()
  3108  	defer rs.closemu.RUnlock()
  3109  
  3110  	if rs.closed {
  3111  		return false
  3112  	}
  3113  
  3114  	rs.lastcols = nil
  3115  	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
  3116  	if !ok {
  3117  		doClose = true
  3118  		return false
  3119  	}
  3120  
  3121  	// Lock the driver connection before calling the driver interface
  3122  	// rowsi to prevent a Tx from rolling back the connection at the same time.
  3123  	rs.dc.Lock()
  3124  	defer rs.dc.Unlock()
  3125  
  3126  	rs.lasterr = nextResultSet.NextResultSet()
  3127  	if rs.lasterr != nil {
  3128  		doClose = true
  3129  		return false
  3130  	}
  3131  	return true
  3132  }
  3133  
  3134  // Err returns the error, if any, that was encountered during iteration.
  3135  // Err may be called after an explicit or implicit [Rows.Close].
  3136  func (rs *Rows) Err() error {
  3137  	// Return any context error that might've happened during row iteration,
  3138  	// but only if we haven't reported the final Next() = false after rows
  3139  	// are done, in which case the user might've canceled their own context
  3140  	// before calling Rows.Err.
  3141  	if !rs.hitEOF {
  3142  		if errp := rs.contextDone.Load(); errp != nil {
  3143  			return *errp
  3144  		}
  3145  	}
  3146  
  3147  	rs.closemu.RLock()
  3148  	defer rs.closemu.RUnlock()
  3149  	return rs.lasterrOrErrLocked(nil)
  3150  }
  3151  
  3152  // rawbuf returns the buffer to append RawBytes values to.
  3153  // This buffer is reused across calls to Rows.Scan.
  3154  //
  3155  // Usage:
  3156  //
  3157  //	rawBytes = rows.setrawbuf(append(rows.rawbuf(), value...))
  3158  func (rs *Rows) rawbuf() []byte {
  3159  	if rs == nil {
  3160  		// convertAssignRows can take a nil *Rows; for simplicity handle it here
  3161  		return nil
  3162  	}
  3163  	return rs.raw
  3164  }
  3165  
  3166  // setrawbuf updates the RawBytes buffer with the result of appending a new value to it.
  3167  // It returns the new value.
  3168  func (rs *Rows) setrawbuf(b []byte) RawBytes {
  3169  	if rs == nil {
  3170  		// convertAssignRows can take a nil *Rows; for simplicity handle it here
  3171  		return RawBytes(b)
  3172  	}
  3173  	off := len(rs.raw)
  3174  	rs.raw = b
  3175  	return RawBytes(rs.raw[off:])
  3176  }
  3177  
  3178  var errRowsClosed = errors.New("sql: Rows are closed")
  3179  var errNoRows = errors.New("sql: no Rows available")
  3180  
  3181  // Columns returns the column names.
  3182  // Columns returns an error if the rows are closed.
  3183  func (rs *Rows) Columns() ([]string, error) {
  3184  	rs.closemu.RLock()
  3185  	defer rs.closemu.RUnlock()
  3186  	if rs.closed {
  3187  		return nil, rs.lasterrOrErrLocked(errRowsClosed)
  3188  	}
  3189  	if rs.rowsi == nil {
  3190  		return nil, rs.lasterrOrErrLocked(errNoRows)
  3191  	}
  3192  	rs.dc.Lock()
  3193  	defer rs.dc.Unlock()
  3194  
  3195  	return rs.rowsi.Columns(), nil
  3196  }
  3197  
  3198  // ColumnTypes returns column information such as column type, length,
  3199  // and nullable. Some information may not be available from some drivers.
  3200  func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
  3201  	rs.closemu.RLock()
  3202  	defer rs.closemu.RUnlock()
  3203  	if rs.closed {
  3204  		return nil, rs.lasterrOrErrLocked(errRowsClosed)
  3205  	}
  3206  	if rs.rowsi == nil {
  3207  		return nil, rs.lasterrOrErrLocked(errNoRows)
  3208  	}
  3209  	rs.dc.Lock()
  3210  	defer rs.dc.Unlock()
  3211  
  3212  	return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
  3213  }
  3214  
  3215  // ColumnType contains the name and type of a column.
  3216  type ColumnType struct {
  3217  	name string
  3218  
  3219  	hasNullable       bool
  3220  	hasLength         bool
  3221  	hasPrecisionScale bool
  3222  
  3223  	nullable     bool
  3224  	length       int64
  3225  	databaseType string
  3226  	precision    int64
  3227  	scale        int64
  3228  	scanType     reflect.Type
  3229  }
  3230  
  3231  // Name returns the name or alias of the column.
  3232  func (ci *ColumnType) Name() string {
  3233  	return ci.name
  3234  }
  3235  
  3236  // Length returns the column type length for variable length column types such
  3237  // as text and binary field types. If the type length is unbounded the value will
  3238  // be [math.MaxInt64] (any database limits will still apply).
  3239  // If the column type is not variable length, such as an int, or if not supported
  3240  // by the driver ok is false.
  3241  func (ci *ColumnType) Length() (length int64, ok bool) {
  3242  	return ci.length, ci.hasLength
  3243  }
  3244  
  3245  // DecimalSize returns the scale and precision of a decimal type.
  3246  // If not applicable or if not supported ok is false.
  3247  func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
  3248  	return ci.precision, ci.scale, ci.hasPrecisionScale
  3249  }
  3250  
  3251  // ScanType returns a Go type suitable for scanning into using [Rows.Scan].
  3252  // If a driver does not support this property ScanType will return
  3253  // the type of an empty interface.
  3254  func (ci *ColumnType) ScanType() reflect.Type {
  3255  	return ci.scanType
  3256  }
  3257  
  3258  // Nullable reports whether the column may be null.
  3259  // If a driver does not support this property ok will be false.
  3260  func (ci *ColumnType) Nullable() (nullable, ok bool) {
  3261  	return ci.nullable, ci.hasNullable
  3262  }
  3263  
  3264  // DatabaseTypeName returns the database system name of the column type. If an empty
  3265  // string is returned, then the driver type name is not supported.
  3266  // Consult your driver documentation for a list of driver data types. [ColumnType.Length] specifiers
  3267  // are not included.
  3268  // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
  3269  // "INT", and "BIGINT".
  3270  func (ci *ColumnType) DatabaseTypeName() string {
  3271  	return ci.databaseType
  3272  }
  3273  
  3274  func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
  3275  	names := rowsi.Columns()
  3276  
  3277  	list := make([]*ColumnType, len(names))
  3278  	for i := range list {
  3279  		ci := &ColumnType{
  3280  			name: names[i],
  3281  		}
  3282  		list[i] = ci
  3283  
  3284  		if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
  3285  			ci.scanType = prop.ColumnTypeScanType(i)
  3286  		} else {
  3287  			ci.scanType = reflect.TypeFor[any]()
  3288  		}
  3289  		if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
  3290  			ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
  3291  		}
  3292  		if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
  3293  			ci.length, ci.hasLength = prop.ColumnTypeLength(i)
  3294  		}
  3295  		if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
  3296  			ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
  3297  		}
  3298  		if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
  3299  			ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
  3300  		}
  3301  	}
  3302  	return list
  3303  }
  3304  
  3305  // Scan copies the columns in the current row into the values pointed
  3306  // at by dest. The number of values in dest must be the same as the
  3307  // number of columns in [Rows].
  3308  //
  3309  // Scan converts columns read from the database into the following
  3310  // common Go types and special types provided by the sql package:
  3311  //
  3312  //	*string
  3313  //	*[]byte
  3314  //	*int, *int8, *int16, *int32, *int64
  3315  //	*uint, *uint8, *uint16, *uint32, *uint64
  3316  //	*bool
  3317  //	*float32, *float64
  3318  //	*interface{}
  3319  //	*RawBytes
  3320  //	*Rows (cursor value)
  3321  //	any type implementing Scanner (see Scanner docs)
  3322  //
  3323  // In the most simple case, if the type of the value from the source
  3324  // column is an integer, bool or string type T and dest is of type *T,
  3325  // Scan simply assigns the value through the pointer.
  3326  //
  3327  // Scan also converts between string and numeric types, as long as no
  3328  // information would be lost. While Scan stringifies all numbers
  3329  // scanned from numeric database columns into *string, scans into
  3330  // numeric types are checked for overflow. For example, a float64 with
  3331  // value 300 or a string with value "300" can scan into a uint16, but
  3332  // not into a uint8, though float64(255) or "255" can scan into a
  3333  // uint8. One exception is that scans of some float64 numbers to
  3334  // strings may lose information when stringifying. In general, scan
  3335  // floating point columns into *float64.
  3336  //
  3337  // If a dest argument has type *[]byte, Scan saves in that argument a
  3338  // copy of the corresponding data. The copy is owned by the caller and
  3339  // can be modified and held indefinitely. The copy can be avoided by
  3340  // using an argument of type [*RawBytes] instead; see the documentation
  3341  // for [RawBytes] for restrictions on its use.
  3342  //
  3343  // If an argument has type *interface{}, Scan copies the value
  3344  // provided by the underlying driver without conversion. When scanning
  3345  // from a source value of type []byte to *interface{}, a copy of the
  3346  // slice is made and the caller owns the result.
  3347  //
  3348  // Source values of type [time.Time] may be scanned into values of type
  3349  // *time.Time, *interface{}, *string, or *[]byte. When converting to
  3350  // the latter two, [time.RFC3339Nano] is used.
  3351  //
  3352  // Source values of type bool may be scanned into types *bool,
  3353  // *interface{}, *string, *[]byte, or [*RawBytes].
  3354  //
  3355  // For scanning into *bool, the source may be true, false, 1, 0, or
  3356  // string inputs parseable by [strconv.ParseBool].
  3357  //
  3358  // Scan can also convert a cursor returned from a query, such as
  3359  // "select cursor(select * from my_table) from dual", into a
  3360  // [*Rows] value that can itself be scanned from. The parent
  3361  // select query will close any cursor [*Rows] if the parent [*Rows] is closed.
  3362  //
  3363  // If any of the first arguments implementing [Scanner] returns an error,
  3364  // that error will be wrapped in the returned error.
  3365  func (rs *Rows) Scan(dest ...any) error {
  3366  	if rs.closemuScanHold {
  3367  		// This should only be possible if the user calls Scan twice in a row
  3368  		// without calling Next.
  3369  		return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
  3370  	}
  3371  	rs.closemu.RLock()
  3372  
  3373  	if rs.lasterr != nil && rs.lasterr != io.EOF {
  3374  		rs.closemu.RUnlock()
  3375  		return rs.lasterr
  3376  	}
  3377  	if rs.closed {
  3378  		err := rs.lasterrOrErrLocked(errRowsClosed)
  3379  		rs.closemu.RUnlock()
  3380  		return err
  3381  	}
  3382  
  3383  	if scanArgsContainRawBytes(dest) {
  3384  		rs.closemuScanHold = true
  3385  		rs.raw = rs.raw[:0]
  3386  	} else {
  3387  		rs.closemu.RUnlock()
  3388  	}
  3389  
  3390  	if rs.lastcols == nil {
  3391  		rs.closemuRUnlockIfHeldByScan()
  3392  		return errors.New("sql: Scan called without calling Next")
  3393  	}
  3394  	if len(dest) != len(rs.lastcols) {
  3395  		rs.closemuRUnlockIfHeldByScan()
  3396  		return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
  3397  	}
  3398  
  3399  	for i, sv := range rs.lastcols {
  3400  		err := convertAssignRows(dest[i], sv, rs)
  3401  		if err != nil {
  3402  			rs.closemuRUnlockIfHeldByScan()
  3403  			return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
  3404  		}
  3405  	}
  3406  	return nil
  3407  }
  3408  
  3409  // closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous
  3410  // call to Scan with *RawBytes.
  3411  func (rs *Rows) closemuRUnlockIfHeldByScan() {
  3412  	if rs.closemuScanHold {
  3413  		rs.closemuScanHold = false
  3414  		rs.closemu.RUnlock()
  3415  	}
  3416  }
  3417  
  3418  func scanArgsContainRawBytes(args []any) bool {
  3419  	for _, a := range args {
  3420  		if _, ok := a.(*RawBytes); ok {
  3421  			return true
  3422  		}
  3423  	}
  3424  	return false
  3425  }
  3426  
// rowsCloseHook returns a function so tests may install the
// hook through a test only mutex.
// The default returns nil, meaning no hook is installed. When a hook is
// returned, Rows.close calls it with the Rows and a pointer to the error
// from closing the driver rows.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
  3430  
  3431  // Close closes the [Rows], preventing further enumeration. If [Rows.Next] is called
  3432  // and returns false and there are no further result sets,
  3433  // the [Rows] are closed automatically and it will suffice to check the
  3434  // result of [Rows.Err]. Close is idempotent and does not affect the result of [Rows.Err].
  3435  func (rs *Rows) Close() error {
  3436  	// If the user's calling Close, they're done with their previous row's Scan
  3437  	// results (any RawBytes memory), so we can release the read lock that would
  3438  	// be preventing awaitDone from calling the unexported close before we do so.
  3439  	rs.closemuRUnlockIfHeldByScan()
  3440  
  3441  	return rs.close(nil)
  3442  }
  3443  
// close closes the Rows. The err argument is the reason for closing, if
// any; it is recorded in lasterr when no error is recorded yet.
// Only the first call does any work: it closes the driver rows, cancels
// the context-watching goroutine, closes closeStmt if set, and releases
// the connection. Later calls return nil.
// The returned error is the one from closing the driver rows.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	// From here on err holds the result of closing the driver rows.
	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	// The hook (installed by tests) receives a pointer to err.
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	// cancel is nil when initContextClose did not start awaitDone.
	if rs.cancel != nil {
		rs.cancel()
	}

	// The error from closing the statement is not reported.
	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)

	rs.lasterr = rs.lasterrOrErrLocked(err)
	return err
}
  3475  
// Row is the result of calling [DB.QueryRow] to select a single row.
type Row struct {
	// One of these two will be non-nil:
	err  error // deferred error for easy chaining
	rows *Rows // closed by Scan once the first row has been read
}
  3482  
  3483  // Scan copies the columns from the matched row into the values
  3484  // pointed at by dest. See the documentation on [Rows.Scan] for details.
  3485  // If more than one row matches the query,
  3486  // Scan uses the first row and discards the rest. If no row matches
  3487  // the query, Scan returns [ErrNoRows].
  3488  func (r *Row) Scan(dest ...any) error {
  3489  	if r.err != nil {
  3490  		return r.err
  3491  	}
  3492  
  3493  	// TODO(bradfitz): for now we need to defensively clone all
  3494  	// []byte that the driver returned (not permitting
  3495  	// *RawBytes in Rows.Scan), since we're about to close
  3496  	// the Rows in our defer, when we return from this function.
  3497  	// the contract with the driver.Next(...) interface is that it
  3498  	// can return slices into read-only temporary memory that's
  3499  	// only valid until the next Scan/Close. But the TODO is that
  3500  	// for a lot of drivers, this copy will be unnecessary. We
  3501  	// should provide an optional interface for drivers to
  3502  	// implement to say, "don't worry, the []bytes that I return
  3503  	// from Next will not be modified again." (for instance, if
  3504  	// they were obtained from the network anyway) But for now we
  3505  	// don't care.
  3506  	defer r.rows.Close()
  3507  	if scanArgsContainRawBytes(dest) {
  3508  		return errors.New("sql: RawBytes isn't allowed on Row.Scan")
  3509  	}
  3510  
  3511  	if !r.rows.Next() {
  3512  		if err := r.rows.Err(); err != nil {
  3513  			return err
  3514  		}
  3515  		return ErrNoRows
  3516  	}
  3517  	err := r.rows.Scan(dest...)
  3518  	if err != nil {
  3519  		return err
  3520  	}
  3521  	// Make sure the query can be processed to completion with no errors.
  3522  	return r.rows.Close()
  3523  }
  3524  
// Err provides a way for wrapping packages to check for
// query errors without calling [Row.Scan].
// Err returns the error, if any, that was encountered while running the query.
// If this error is not nil, this error will also be returned from [Row.Scan].
// Err only returns the stored deferred error; it does not read from or
// close the underlying rows.
func (r *Row) Err() error {
	return r.err
}
  3532  
// A Result summarizes an executed SQL command.
//
// Each method also returns an error, which callers should check, since
// support for these values varies by database and driver.
type Result interface {
	// LastInsertId returns the integer generated by the database
	// in response to a command. Typically this will be from an
	// "auto increment" column when inserting a new row. Not all
	// databases support this feature, and the syntax of such
	// statements varies.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows affected by an
	// update, insert, or delete. Not every database or database
	// driver may support this.
	RowsAffected() (int64, error)
}
  3547  
// driverResult is a [Result] implementation that wraps a driver.Result.
// Its methods hold the embedded Locker while calling into the wrapped
// result.
type driverResult struct {
	sync.Locker // the *driverConn
	// resi is the driver's result that the methods delegate to.
	resi        driver.Result
}
  3552  
  3553  func (dr driverResult) LastInsertId() (int64, error) {
  3554  	dr.Lock()
  3555  	defer dr.Unlock()
  3556  	return dr.resi.LastInsertId()
  3557  }
  3558  
  3559  func (dr driverResult) RowsAffected() (int64, error) {
  3560  	dr.Lock()
  3561  	defer dr.Unlock()
  3562  	return dr.resi.RowsAffected()
  3563  }
  3564  
  3565  func stack() string {
  3566  	var buf [2 << 10]byte
  3567  	return string(buf[:runtime.Stack(buf[:], false)])
  3568  }
  3569  
// withLock runs fn while holding lk. lk is released when fn returns,
// including when fn panics.
func withLock(lk sync.Locker, fn func()) {
	lk.Lock()
	defer lk.Unlock() // in case fn panics
	fn()
}
  3576  
// connRequestSet is a set of chan connRequest that's
// optimized for:
//
//   - adding an element
//   - removing an element (only by the caller who added it)
//   - taking (get + delete) a random element
//
// We previously used a map for this but the take of a random element
// was expensive, making mapiters. This type avoids a map entirely
// and just uses a slice.
//
// The zero value is an empty set. The methods do no locking of their own.
// NOTE(review): callers presumably serialize access with their own lock; confirm.
type connRequestSet struct {
	// s are the elements in the set. Their order is not meaningful:
	// removing an element moves the last element into the vacated slot.
	s []connRequestAndIndex
}
  3591  
// connRequestAndIndex is one entry of connRequestSet.s: a request channel
// paired with a pointer that tracks where the entry currently lives.
type connRequestAndIndex struct {
	// req is the element in the set.
	req chan connRequest

	// curIdx points to the current location of this element in
	// connRequestSet.s. It gets set to -1 upon removal.
	curIdx *int
}
  3600  
  3601  // CloseAndRemoveAll closes all channels in the set
  3602  // and clears the set.
  3603  func (s *connRequestSet) CloseAndRemoveAll() {
  3604  	for _, v := range s.s {
  3605  		*v.curIdx = -1
  3606  		close(v.req)
  3607  	}
  3608  	s.s = nil
  3609  }
  3610  
// Len returns the number of elements currently in the set.
func (s *connRequestSet) Len() int { return len(s.s) }
  3613  
// connRequestDelHandle is an opaque handle returned by connRequestSet.Add
// that can later be passed to connRequestSet.Delete to remove the item
// that was added.
type connRequestDelHandle struct {
	idx *int // pointer to the item's index; the pointed-to value is -1 if not in slice
}
  3619  
  3620  // Add adds v to the set of waiting requests.
  3621  // The returned connRequestDelHandle can be used to remove the item from
  3622  // the set.
  3623  func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
  3624  	idx := len(s.s)
  3625  	// TODO(bradfitz): for simplicity, this always allocates a new int-sized
  3626  	// allocation to store the index. But generally the set will be small and
  3627  	// under a scannable-threshold. As an optimization, we could permit the *int
  3628  	// to be nil when the set is small and should be scanned. This works even if
  3629  	// the set grows over the threshold with delete handles outstanding because
  3630  	// an element can only move to a lower index. So if it starts with a nil
  3631  	// position, it'll always be in a low index and thus scannable. But that
  3632  	// can be done in a follow-up change.
  3633  	idxPtr := &idx
  3634  	s.s = append(s.s, connRequestAndIndex{v, idxPtr})
  3635  	return connRequestDelHandle{idxPtr}
  3636  }
  3637  
  3638  // Delete removes an element from the set.
  3639  //
  3640  // It reports whether the element was deleted. (It can return false if a caller
  3641  // of TakeRandom took it meanwhile, or upon the second call to Delete)
  3642  func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
  3643  	idx := *h.idx
  3644  	if idx < 0 {
  3645  		return false
  3646  	}
  3647  	s.deleteIndex(idx)
  3648  	return true
  3649  }
  3650  
  3651  func (s *connRequestSet) deleteIndex(idx int) {
  3652  	// Mark item as deleted.
  3653  	*(s.s[idx].curIdx) = -1
  3654  	// Copy last element, updating its position
  3655  	// to its new home.
  3656  	if idx < len(s.s)-1 {
  3657  		last := s.s[len(s.s)-1]
  3658  		*last.curIdx = idx
  3659  		s.s[idx] = last
  3660  	}
  3661  	// Zero out last element (for GC) before shrinking the slice.
  3662  	s.s[len(s.s)-1] = connRequestAndIndex{}
  3663  	s.s = s.s[:len(s.s)-1]
  3664  }
  3665  
  3666  // TakeRandom returns and removes a random element from s
  3667  // and reports whether there was one to take. (It returns ok=false
  3668  // if the set is empty.)
  3669  func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
  3670  	if len(s.s) == 0 {
  3671  		return nil, false
  3672  	}
  3673  	pick := rand.IntN(len(s.s))
  3674  	e := s.s[pick]
  3675  	s.deleteIndex(pick)
  3676  	return e.req, true
  3677  }
  3678  

View as plain text