1
2
3
4
5
6
7 package sql
8
9 import (
10 "bytes"
11 "database/sql/driver"
12 "errors"
13 "fmt"
14 "reflect"
15 "strconv"
16 "time"
17 "unicode"
18 "unicode/utf8"
19 _ "unsafe"
20 "uuid"
21 )
22
var (
	// errNilPtr is returned by the scan conversion code when the destination
	// it was asked to write through is a nil pointer.
	errNilPtr = errors.New("destination pointer is nil")
)
24
// describeNamedValue produces a short label for nv for use in error messages:
// `with name "X"` when the value is named, otherwise "$N" where N is its
// ordinal position.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
31
// validateNamedValueName accepts the empty name (a positional argument) and
// any name whose first rune is a Unicode letter; every other name yields an
// error.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
42
43
44
45
46 type ccChecker struct {
47 cci driver.ColumnConverter
48 want int
49 }
50
51 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
52 if c.cci == nil {
53 return driver.ErrSkip
54 }
55
56
57
58 index := nv.Ordinal - 1
59 if c.want >= 0 && c.want <= index {
60 return nil
61 }
62
63
64
65
66 if vr, ok := nv.Value.(driver.Valuer); ok {
67 sv, err := callValuerValue(vr)
68 if err != nil {
69 return err
70 }
71 if !driver.IsValue(sv) {
72 return fmt.Errorf("non-subset type %T returned from Value", sv)
73 }
74 nv.Value = sv
75 }
76
77
78
79
80
81
82
83
84 var err error
85 arg := nv.Value
86 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
87 if err != nil {
88 return err
89 }
90 if !driver.IsValue(nv.Value) {
91 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
92 }
93 return nil
94 }
95
96
97
98
// defaultCheckNamedValue converts nv.Value in place with
// driver.DefaultParameterConverter and returns that converter's error.
// driverArgsConnLocked uses it as the last resort when no other checker
// handles an argument.
func defaultCheckNamedValue(nv *driver.NamedValue) error {
	converted, convErr := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return convErr
}
103
104
105
106
107
108
109
110 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
111 nvargs := make([]driver.NamedValue, len(args))
112
113
114
115
116 want := -1
117
118 var si driver.Stmt
119 var cc ccChecker
120 if ds != nil {
121 si = ds.si
122 want = ds.si.NumInput()
123 cc.want = want
124 }
125
126
127
128
129
130 nvc, ok := si.(driver.NamedValueChecker)
131 if !ok {
132 nvc, _ = ci.(driver.NamedValueChecker)
133 }
134 cci, ok := si.(driver.ColumnConverter)
135 if ok {
136 cc.cci = cci
137 }
138
139
140
141
142
143
144 var err error
145 var n int
146 for _, arg := range args {
147 nv := &nvargs[n]
148 if np, ok := arg.(NamedArg); ok {
149 if err = validateNamedValueName(np.Name); err != nil {
150 return nil, err
151 }
152 arg = np.Value
153 nv.Name = np.Name
154 }
155 nv.Ordinal = n + 1
156 nv.Value = arg
157
158
159
160
161
162
163
164
165
166
167
168
169 checker := defaultCheckNamedValue
170 nextCC := false
171 switch {
172 case nvc != nil:
173 nextCC = cci != nil
174 checker = nvc.CheckNamedValue
175 case cci != nil:
176 checker = cc.CheckNamedValue
177 }
178
179 nextCheck:
180 err = checker(nv)
181 switch err {
182 case nil:
183 n++
184 continue
185 case driver.ErrRemoveArgument:
186 nvargs = nvargs[:len(nvargs)-1]
187 continue
188 case driver.ErrSkip:
189 if nextCC {
190 nextCC = false
191 checker = cc.CheckNamedValue
192 } else {
193 checker = defaultCheckNamedValue
194 }
195 goto nextCheck
196 default:
197 return nil, fmt.Errorf("sql: converting argument %s type: %w", describeNamedValue(nv), err)
198 }
199 }
200
201
202
203 if want != -1 && len(nvargs) != want {
204 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
205 }
206
207 return nvargs, nil
208 }
209
210
211
212
213
214
215
216
217
218
219
220
221
222 func convertAssign(dest, src any) error {
223 return convertAssignRows(dest, src, nil)
224 }
225
226
227
228
229
230
231 func convertAssignRows(dest, src any, rows *Rows) error {
232
233 switch s := src.(type) {
234 case string:
235 switch d := dest.(type) {
236 case *string:
237 if d == nil {
238 return errNilPtr
239 }
240 *d = s
241 return nil
242 case *[]byte:
243 if d == nil {
244 return errNilPtr
245 }
246 *d = []byte(s)
247 return nil
248 case *RawBytes:
249 if d == nil {
250 return errNilPtr
251 }
252 *d = rows.setrawbuf(append(rows.rawbuf(), s...))
253 return nil
254 case *uuid.UUID:
255 if d == nil {
256 return errNilPtr
257 }
258 u, err := uuid.Parse(s)
259 if err != nil {
260 return fmt.Errorf("converting driver.Value type string (%q) to a UUID: %v", s, err)
261 }
262 *d = u
263 return nil
264 }
265 case []byte:
266 switch d := dest.(type) {
267 case *string:
268 if d == nil {
269 return errNilPtr
270 }
271 *d = string(s)
272 return nil
273 case *any:
274 if d == nil {
275 return errNilPtr
276 }
277 *d = bytes.Clone(s)
278 return nil
279 case *[]byte:
280 if d == nil {
281 return errNilPtr
282 }
283 *d = bytes.Clone(s)
284 return nil
285 case *RawBytes:
286 if d == nil {
287 return errNilPtr
288 }
289 *d = s
290 return nil
291 case *uuid.UUID:
292 if d == nil {
293 return errNilPtr
294 }
295 if len(s) == len(*d) {
296 copy((*d)[:], s)
297 return nil
298 }
299 var u uuid.UUID
300 err := u.UnmarshalText(s)
301 if err != nil {
302 return fmt.Errorf("converting driver.Value type []byte (%q) to a UUID: %v", s, err)
303 }
304 *d = u
305 return nil
306 }
307 case time.Time:
308 switch d := dest.(type) {
309 case *time.Time:
310 *d = s
311 return nil
312 case *string:
313 *d = s.Format(time.RFC3339Nano)
314 return nil
315 case *[]byte:
316 if d == nil {
317 return errNilPtr
318 }
319 *d = s.AppendFormat(make([]byte, 0, len(time.RFC3339Nano)), time.RFC3339Nano)
320 return nil
321 case *RawBytes:
322 if d == nil {
323 return errNilPtr
324 }
325 *d = rows.setrawbuf(s.AppendFormat(rows.rawbuf(), time.RFC3339Nano))
326 return nil
327 }
328 case decimalDecompose:
329 switch d := dest.(type) {
330 case decimalCompose:
331 return d.Compose(s.Decompose(nil))
332 }
333 case nil:
334 switch d := dest.(type) {
335 case *any:
336 if d == nil {
337 return errNilPtr
338 }
339 *d = nil
340 return nil
341 case *[]byte:
342 if d == nil {
343 return errNilPtr
344 }
345 *d = nil
346 return nil
347 case *RawBytes:
348 if d == nil {
349 return errNilPtr
350 }
351 *d = nil
352 return nil
353 }
354
355 case driver.Rows:
356 switch d := dest.(type) {
357 case *Rows:
358 if d == nil {
359 return errNilPtr
360 }
361 if rows == nil {
362 return errors.New("invalid context to convert cursor rows, missing parent *Rows")
363 }
364 *d = Rows{
365 dc: rows.dc,
366 releaseConn: func(error) {},
367 rowsi: s,
368 }
369
370 parentCancel := rows.cancel
371 rows.cancel = func() {
372
373
374 d.close(rows.lasterr)
375 if parentCancel != nil {
376 parentCancel()
377 }
378 }
379 return nil
380 }
381 }
382
383 var sv reflect.Value
384
385 switch d := dest.(type) {
386 case *string:
387 sv = reflect.ValueOf(src)
388 switch sv.Kind() {
389 case reflect.Bool,
390 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
391 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
392 reflect.Float32, reflect.Float64:
393 *d = asString(src)
394 return nil
395 }
396 case *[]byte:
397 sv = reflect.ValueOf(src)
398 if b, ok := asBytes(nil, sv); ok {
399 *d = b
400 return nil
401 }
402 case *RawBytes:
403 sv = reflect.ValueOf(src)
404 if b, ok := asBytes(rows.rawbuf(), sv); ok {
405 *d = rows.setrawbuf(b)
406 return nil
407 }
408 case *bool:
409 bv, err := driver.Bool.ConvertValue(src)
410 if err == nil {
411 *d = bv.(bool)
412 }
413 return err
414 case *any:
415 *d = src
416 return nil
417 }
418
419 if scanner, ok := dest.(Scanner); ok {
420 return scanner.Scan(src)
421 }
422
423 dpv := reflect.ValueOf(dest)
424 if dpv.Kind() != reflect.Pointer {
425 return errors.New("destination not a pointer")
426 }
427 if dpv.IsNil() {
428 return errNilPtr
429 }
430
431 if !sv.IsValid() {
432 sv = reflect.ValueOf(src)
433 }
434
435 dv := reflect.Indirect(dpv)
436 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
437 switch b := src.(type) {
438 case []byte:
439 dv.Set(reflect.ValueOf(bytes.Clone(b)))
440 default:
441 dv.Set(sv)
442 }
443 return nil
444 }
445
446 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
447 dv.Set(sv.Convert(dv.Type()))
448 return nil
449 }
450
451
452
453
454
455
456 switch dv.Kind() {
457 case reflect.Pointer:
458 if src == nil {
459 dv.SetZero()
460 return nil
461 }
462 dv.Set(reflect.New(dv.Type().Elem()))
463 return convertAssignRows(dv.Interface(), src, rows)
464 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
465 if src == nil {
466 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
467 }
468 s := asString(src)
469 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
470 if err != nil {
471 err = strconvErr(err)
472 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
473 }
474 dv.SetInt(i64)
475 return nil
476 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
477 if src == nil {
478 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
479 }
480 s := asString(src)
481 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
482 if err != nil {
483 err = strconvErr(err)
484 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
485 }
486 dv.SetUint(u64)
487 return nil
488 case reflect.Float32, reflect.Float64:
489 if src == nil {
490 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
491 }
492 s := asString(src)
493 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
494 if err != nil {
495 err = strconvErr(err)
496 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
497 }
498 dv.SetFloat(f64)
499 return nil
500 case reflect.String:
501 if src == nil {
502 return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
503 }
504 switch v := src.(type) {
505 case string:
506 dv.SetString(v)
507 return nil
508 case []byte:
509 dv.SetString(string(v))
510 return nil
511 }
512 }
513
514 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
515 }
516
// strconvErr unwraps a *strconv.NumError to its underlying cause (such as
// strconv.ErrSyntax or strconv.ErrRange) so error messages omit the
// redundant function/input prefix. Any other error, including nil, is
// returned unchanged.
func strconvErr(err error) error {
	numErr, isNum := err.(*strconv.NumError)
	if !isNum {
		return err
	}
	return numErr.Err
}
523
// asString renders src as a string for the conversions in convertAssignRows.
// A string is returned as-is and a []byte is copied into a string. Values
// whose reflect kind is a signed or unsigned integer, a float, or a bool
// (including named types with those kinds) are formatted with strconv.
// Everything else falls back to fmt's %v formatting.
func asString(src any) string {
	if str, isStr := src.(string); isStr {
		return str
	}
	if raw, isBytes := src.([]byte); isBytes {
		return string(raw)
	}

	val := reflect.ValueOf(src)
	switch kind := val.Kind(); kind {
	case reflect.Bool:
		return strconv.FormatBool(val.Bool())
	case reflect.Float32, reflect.Float64:
		bitSize := 64
		if kind == reflect.Float32 {
			bitSize = 32
		}
		return strconv.FormatFloat(val.Float(), 'g', -1, bitSize)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(val.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(val.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
546
// asBytes appends the textual form of rv to buf and reports whether rv's kind
// is supported: signed and unsigned integers, floats, bools and strings. For
// any other kind, including an invalid reflect.Value, it returns (nil, false)
// without writing to buf.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	kind := rv.Kind()
	switch kind {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		bitSize := 32
		if kind == reflect.Float64 {
			bitSize = 64
		}
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, bitSize), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	}
	return nil, false
}
565
// valuerReflectType is the reflect.Type of the driver.Valuer interface.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), with one exception. When vr holds a nil
// pointer whose element type itself implements driver.Valuer, it returns
// (nil, nil) without calling Value. In that case Value has a value receiver,
// and invoking it through a nil pointer would panic. Nil pointers whose Value
// method has a pointer receiver are still called, so the method can handle
// the nil receiver itself.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	isNilPtr := rv.Kind() == reflect.Pointer && rv.IsNil()
	if isNilPtr && rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
type (
	// decimal is satisfied by values that can both break themselves into
	// parts (decimalDecompose) and be rebuilt from such parts
	// (decimalCompose).
	decimal interface {
		decimalCompose
		decimalDecompose
	}

	// decimalDecompose is implemented by values that can report themselves as
	// a (form, negative, coefficient, exponent) tuple. convertAssignRows
	// passes that tuple straight to decimalCompose.Compose on the destination.
	decimalDecompose interface {
		// Decompose returns the value's parts. convertAssignRows passes a nil
		// buf. NOTE(review): buf is presumably scratch storage the
		// implementation may use for coefficient — confirm.
		Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
	}

	// decimalCompose is implemented by values that can be set from a
	// (form, negative, coefficient, exponent) tuple.
	decimalCompose interface {
		// Compose sets the receiver from the given parts and reports any
		// failure to do so.
		Compose(form byte, negative bool, coefficient []byte, exponent int32) error
	}
)
626
View as plain text