1
2
3
4
5 package template
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/url"
12 "reflect"
13 "strings"
14 "sync"
15 "unicode"
16 "unicode/utf8"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// FuncMap is the type of the map defining the mapping from names to
// functions. Each value must be a function (addValueFuncs panics otherwise)
// with either a single result or two results where the second is of type
// error (checked by goodFunc).
type FuncMap map[string]any
34
35
36
37
38
39 func builtins() FuncMap {
40 return FuncMap{
41 "and": and,
42 "call": emptyCall,
43 "html": HTMLEscaper,
44 "index": index,
45 "slice": slice,
46 "js": JSEscaper,
47 "len": length,
48 "not": not,
49 "or": or,
50 "print": fmt.Sprint,
51 "printf": fmt.Sprintf,
52 "println": fmt.Sprintln,
53 "urlquery": URLQueryEscaper,
54
55
56 "eq": eq,
57 "ge": ge,
58 "gt": gt,
59 "le": le,
60 "lt": lt,
61 "ne": ne,
62 }
63 }
64
// builtinFuncsOnce caches the reflect.Value form of builtins(). The embedded
// sync.Once guards the single initialization performed by builtinFuncs; v
// holds the resulting map.
var builtinFuncsOnce struct {
	sync.Once
	v map[string]reflect.Value
}
69
70
71
72 func builtinFuncs() map[string]reflect.Value {
73 builtinFuncsOnce.Do(func() {
74 builtinFuncsOnce.v = createValueFuncs(builtins())
75 })
76 return builtinFuncsOnce.v
77 }
78
79
80 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
81 m := make(map[string]reflect.Value)
82 addValueFuncs(m, funcMap)
83 return m
84 }
85
86
87 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
88 for name, fn := range in {
89 if !goodName(name) {
90 panic(fmt.Errorf("function name %q is not a valid identifier", name))
91 }
92 v := reflect.ValueOf(fn)
93 if v.Kind() != reflect.Func {
94 panic("value for " + name + " not a function")
95 }
96 if err := goodFunc(name, v.Type()); err != nil {
97 panic(err)
98 }
99 out[name] = v
100 }
101 }
102
103
104
105 func addFuncs(out, in FuncMap) {
106 for name, fn := range in {
107 out[name] = fn
108 }
109 }
110
111
112 func goodFunc(name string, typ reflect.Type) error {
113
114 switch numOut := typ.NumOut(); {
115 case numOut == 1:
116 return nil
117 case numOut == 2 && typ.Out(1) == errorType:
118 return nil
119 case numOut == 2:
120 return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
121 default:
122 return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
123 }
124 }
125
126
// goodName reports whether name may be used as a template function name.
// A valid name is non-empty and contains only:
//   - underscores,
//   - Unicode letters (unicode.IsLetter),
//   - Unicode digits (unicode.IsDigit), except in the first position.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' {
			continue
		}
		// Letters are always fine; digits only after the first rune.
		valid := unicode.IsLetter(ch) || (pos > 0 && unicode.IsDigit(ch))
		if !valid {
			return false
		}
	}
	return true
}
142
143
144 func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
145 if tmpl != nil && tmpl.common != nil {
146 tmpl.muFuncs.RLock()
147 defer tmpl.muFuncs.RUnlock()
148 if fn := tmpl.execFuncs[name]; fn.IsValid() {
149 return fn, false, true
150 }
151 }
152 if fn := builtinFuncs()[name]; fn.IsValid() {
153 return fn, true, true
154 }
155 return reflect.Value{}, false, false
156 }
157
158
159
160 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
161 if !value.IsValid() {
162 if !canBeNil(argType) {
163 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
164 }
165 value = reflect.Zero(argType)
166 }
167 if value.Type().AssignableTo(argType) {
168 return value, nil
169 }
170 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
171 value = value.Convert(argType)
172 return value, nil
173 }
174 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
175 }
176
// intLike reports whether typ is a signed or unsigned integer kind,
// uintptr included.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
186
187
// indexArg checks that index is an integer usable as a position in a
// collection whose bound is cap, and returns it as an int.
//
// The bound is inclusive (index == cap is accepted), which is what slice
// expressions need. Negative values, and unsigned values too large for
// int64/int, are rejected. A nil or non-integer index is an error.
func indexArg(index reflect.Value, cap int) (int, error) {
	var pos int64
	switch index.Kind() {
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		pos = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Values above MaxInt64 wrap negative here and are rejected below.
		pos = int64(index.Uint())
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	n := int(pos)
	// n < 0 catches truncation when int is narrower than int64.
	if pos < 0 || n < 0 || n > cap {
		return 0, fmt.Errorf("index out of range: %d", pos)
	}
	return n, nil
}
205
206
207
208
209
210
211 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
212 item = indirectInterface(item)
213 if !item.IsValid() {
214 return reflect.Value{}, fmt.Errorf("index of untyped nil")
215 }
216 for _, index := range indexes {
217 index = indirectInterface(index)
218 var isNil bool
219 if item, isNil = indirect(item); isNil {
220 return reflect.Value{}, fmt.Errorf("index of nil pointer")
221 }
222 switch item.Kind() {
223 case reflect.Array, reflect.Slice, reflect.String:
224 x, err := indexArg(index, item.Len())
225 if err != nil {
226 return reflect.Value{}, err
227 }
228 item = item.Index(x)
229 case reflect.Map:
230 index, err := prepareArg(index, item.Type().Key())
231 if err != nil {
232 return reflect.Value{}, err
233 }
234 if x := item.MapIndex(index); x.IsValid() {
235 item = x
236 } else {
237 item = reflect.Zero(item.Type().Elem())
238 }
239 case reflect.Invalid:
240
241 panic("unreachable")
242 default:
243 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
244 }
245 }
246 return item, nil
247 }
248
249
250
251
252
253
254
255 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
256 item = indirectInterface(item)
257 if !item.IsValid() {
258 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
259 }
260 if len(indexes) > 3 {
261 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
262 }
263 var cap int
264 switch item.Kind() {
265 case reflect.String:
266 if len(indexes) == 3 {
267 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
268 }
269 cap = item.Len()
270 case reflect.Array, reflect.Slice:
271 cap = item.Cap()
272 default:
273 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
274 }
275
276 idx := [3]int{0, item.Len()}
277 for i, index := range indexes {
278 x, err := indexArg(index, cap)
279 if err != nil {
280 return reflect.Value{}, err
281 }
282 idx[i] = x
283 }
284
285 if idx[0] > idx[1] {
286 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
287 }
288 if len(indexes) < 3 {
289 return item.Slice(idx[0], idx[1]), nil
290 }
291
292 if idx[1] > idx[2] {
293 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
294 }
295 return item.Slice3(idx[0], idx[1], idx[2]), nil
296 }
297
298
299
300
301 func length(item reflect.Value) (int, error) {
302 item, isNil := indirect(item)
303 if isNil {
304 return 0, fmt.Errorf("len of nil pointer")
305 }
306 switch item.Kind() {
307 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
308 return item.Len(), nil
309 }
310 return 0, fmt.Errorf("len of type %s", item.Type())
311 }
312
313
314
// emptyCall is the placeholder registered as the builtin "call" in builtins;
// its signature only describes the shape of the builtin. It panics if ever
// invoked. Presumably the executor dispatches "call" to the call function
// itself (TODO confirm in exec.go).
func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
318
319
320
321 func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
322 fn = indirectInterface(fn)
323 if !fn.IsValid() {
324 return reflect.Value{}, fmt.Errorf("call of nil")
325 }
326 typ := fn.Type()
327 if typ.Kind() != reflect.Func {
328 return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
329 }
330
331 if err := goodFunc(name, typ); err != nil {
332 return reflect.Value{}, err
333 }
334 numIn := typ.NumIn()
335 var dddType reflect.Type
336 if typ.IsVariadic() {
337 if len(args) < numIn-1 {
338 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
339 }
340 dddType = typ.In(numIn - 1).Elem()
341 } else {
342 if len(args) != numIn {
343 return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
344 }
345 }
346 argv := make([]reflect.Value, len(args))
347 for i, arg := range args {
348 arg = indirectInterface(arg)
349
350 argType := dddType
351 if !typ.IsVariadic() || i < numIn-1 {
352 argType = typ.In(i)
353 }
354
355 var err error
356 if argv[i], err = prepareArg(arg, argType); err != nil {
357 return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
358 }
359 }
360 return safeCall(fn, argv)
361 }
362
363
364
// safeCall calls fun with args and returns its first result.
//
// If fun has a second result that is a non-nil error, that error is returned
// as err. A panic during the call is recovered and returned as err: panic
// values that implement error are returned as-is, anything else is formatted
// with %v. On panic, val is the zero Value.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		e, isErr := r.(error)
		if !isErr {
			e = fmt.Errorf("%v", r)
		}
		err = e
	}()
	results := fun.Call(args)
	val = results[0]
	if len(results) == 2 && !results[1].IsNil() {
		err = results[1].Interface().(error)
	}
	return val, err
}
381
382
383
384 func truth(arg reflect.Value) bool {
385 t, _ := isTrue(indirectInterface(arg))
386 return t
387 }
388
389
390
// and is the placeholder registered as the builtin "and"; its signature only
// describes the builtin's shape (one or more arguments, one result). It
// panics if invoked directly. Presumably the executor evaluates "and" itself,
// e.g. to short-circuit (TODO confirm in exec.go).
func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
394
395
396
// or is the placeholder registered as the builtin "or"; its signature only
// describes the builtin's shape (one or more arguments, one result). It
// panics if invoked directly. Presumably the executor evaluates "or" itself,
// e.g. to short-circuit (TODO confirm in exec.go).
func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
	panic("unreachable")
}
400
401
402 func not(arg reflect.Value) bool {
403 return !truth(arg)
404 }
405
406
407
408
409
// Errors returned by the comparison builtins (eq, ne, lt, le, gt, ge).
var (
	// errBadComparisonType: an operand's kind is not a basic kind (see
	// basicKind), or lt was given bool or complex operands.
	errBadComparisonType = errors.New("invalid type for comparison")
	// errBadComparison: the operands have different basic kinds other than
	// a signed/unsigned integer mix.
	errBadComparison = errors.New("incompatible types for comparison")
	// errNoComparison: eq was called without any value to compare against.
	errNoComparison = errors.New("missing argument for comparison")
)
415
// kind is the coarse category of a reflect.Value used by the comparison
// functions; basicKind computes it.
type kind int

// The comparison categories. invalidKind is the zero value and is what
// basicKind returns for every reflect.Kind it does not recognize.
const (
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)
427
428 func basicKind(v reflect.Value) (kind, error) {
429 switch v.Kind() {
430 case reflect.Bool:
431 return boolKind, nil
432 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
433 return intKind, nil
434 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
435 return uintKind, nil
436 case reflect.Float32, reflect.Float64:
437 return floatKind, nil
438 case reflect.Complex64, reflect.Complex128:
439 return complexKind, nil
440 case reflect.String:
441 return stringKind, nil
442 }
443 return invalidKind, errBadComparisonType
444 }
445
446
// isNil reports whether v is an invalid Value, or a nil value of one of the
// kinds that can be nil: chan, func, interface, map, pointer or slice.
// Values of any other kind are never nil.
func isNil(v reflect.Value) bool {
	if !v.IsValid() {
		return true
	}
	k := v.Kind()
	nilable := k == reflect.Chan || k == reflect.Func || k == reflect.Interface ||
		k == reflect.Map || k == reflect.Pointer || k == reflect.Slice
	return nilable && v.IsNil()
}
457
458
459
// canCompare reports whether eq's generic fallback may compare v1 and v2.
// They must share a reflect.Kind, unless either one is an invalid Value,
// which is compatible with anything.
func canCompare(v1, v2 reflect.Value) bool {
	k1, k2 := v1.Kind(), v2.Kind()
	switch {
	case k1 == reflect.Invalid, k2 == reflect.Invalid:
		return true
	default:
		return k1 == k2
	}
}
469
470
471 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
472 arg1 = indirectInterface(arg1)
473 if len(arg2) == 0 {
474 return false, errNoComparison
475 }
476 k1, _ := basicKind(arg1)
477 for _, arg := range arg2 {
478 arg = indirectInterface(arg)
479 k2, _ := basicKind(arg)
480 truth := false
481 if k1 != k2 {
482
483 switch {
484 case k1 == intKind && k2 == uintKind:
485 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
486 case k1 == uintKind && k2 == intKind:
487 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
488 default:
489 if arg1.IsValid() && arg.IsValid() {
490 return false, errBadComparison
491 }
492 }
493 } else {
494 switch k1 {
495 case boolKind:
496 truth = arg1.Bool() == arg.Bool()
497 case complexKind:
498 truth = arg1.Complex() == arg.Complex()
499 case floatKind:
500 truth = arg1.Float() == arg.Float()
501 case intKind:
502 truth = arg1.Int() == arg.Int()
503 case stringKind:
504 truth = arg1.String() == arg.String()
505 case uintKind:
506 truth = arg1.Uint() == arg.Uint()
507 default:
508 if !canCompare(arg1, arg) {
509 return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
510 }
511 if isNil(arg1) || isNil(arg) {
512 truth = isNil(arg) == isNil(arg1)
513 } else {
514 if !arg.Type().Comparable() {
515 return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
516 }
517 truth = arg1.Interface() == arg.Interface()
518 }
519 }
520 }
521 if truth {
522 return true, nil
523 }
524 }
525 return false, nil
526 }
527
528
529 func ne(arg1, arg2 reflect.Value) (bool, error) {
530
531 equal, err := eq(arg1, arg2)
532 return !equal, err
533 }
534
535
536 func lt(arg1, arg2 reflect.Value) (bool, error) {
537 arg1 = indirectInterface(arg1)
538 k1, err := basicKind(arg1)
539 if err != nil {
540 return false, err
541 }
542 arg2 = indirectInterface(arg2)
543 k2, err := basicKind(arg2)
544 if err != nil {
545 return false, err
546 }
547 truth := false
548 if k1 != k2 {
549
550 switch {
551 case k1 == intKind && k2 == uintKind:
552 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
553 case k1 == uintKind && k2 == intKind:
554 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
555 default:
556 return false, errBadComparison
557 }
558 } else {
559 switch k1 {
560 case boolKind, complexKind:
561 return false, errBadComparisonType
562 case floatKind:
563 truth = arg1.Float() < arg2.Float()
564 case intKind:
565 truth = arg1.Int() < arg2.Int()
566 case stringKind:
567 truth = arg1.String() < arg2.String()
568 case uintKind:
569 truth = arg1.Uint() < arg2.Uint()
570 default:
571 panic("invalid kind")
572 }
573 }
574 return truth, nil
575 }
576
577
578 func le(arg1, arg2 reflect.Value) (bool, error) {
579
580 lessThan, err := lt(arg1, arg2)
581 if lessThan || err != nil {
582 return lessThan, err
583 }
584 return eq(arg1, arg2)
585 }
586
587
588 func gt(arg1, arg2 reflect.Value) (bool, error) {
589
590 lessOrEqual, err := le(arg1, arg2)
591 if err != nil {
592 return false, err
593 }
594 return !lessOrEqual, nil
595 }
596
597
598 func ge(arg1, arg2 reflect.Value) (bool, error) {
599
600 lessThan, err := lt(arg1, arg2)
601 if err != nil {
602 return false, err
603 }
604 return !lessThan, nil
605 }
606
607
608
// Replacement text written by HTMLEscape for characters that are special in
// HTML. NUL bytes become U+FFFD REPLACEMENT CHARACTER.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;"
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD")
)

// HTMLEscape writes to w the escaped HTML equivalent of the plain text data
// b. The characters " ' & < > are replaced by character references and NUL
// bytes by U+FFFD. Runs of bytes needing no escaping are written with a
// single Write call. Errors returned by w.Write are ignored.
func HTMLEscape(w io.Writer, b []byte) {
	last := 0
	for i, c := range b {
		var html []byte
		switch c {
		case '\000':
			html = htmlNull
		case '"':
			html = htmlQuot
		case '\'':
			html = htmlApos
		case '&':
			html = htmlAmp
		case '<':
			html = htmlLt
		case '>':
			html = htmlGt
		default:
			continue
		}
		w.Write(b[last:i])
		w.Write(html)
		last = i + 1
	}
	w.Write(b[last:])
}

// HTMLEscapeString returns the escaped HTML equivalent of the plain text
// data s, as produced by HTMLEscape. s is returned unchanged (without
// allocating) when it contains no character that needs escaping.
func HTMLEscapeString(s string) string {
	if !strings.ContainsAny(s, "'\"&<>\000") {
		return s
	}
	var b strings.Builder
	HTMLEscape(&b, []byte(s))
	return b.String()
}
656
657
658
659 func HTMLEscaper(args ...any) string {
660 return HTMLEscapeString(evalArgs(args))
661 }
662
663
664
// Escape sequences written by JSEscape. hex supplies the digits for the
// \u00XX form used for ASCII control characters.
var (
	jsLowUni = []byte(`\u00`)
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)

// JSEscape writes to w the escaped JavaScript equivalent of the plain text
// data b:
//   - backslash, ' and " are backslash-escaped;
//   - < > & = become \u escapes;
//   - ASCII control characters become \u00XX;
//   - non-ASCII runes are copied verbatim when unicode.IsPrint accepts them,
//     and written as \uXXXX otherwise, using a UTF-16 surrogate pair above
//     U+FFFF since a JavaScript \u escape takes exactly four hex digits.
//
// An invalid UTF-8 byte decodes as U+FFFD, which is printable, so the raw
// byte is copied through. Errors returned by w.Write are ignored.
func JSEscape(w io.Writer, b []byte) {
	last := 0
	for i := 0; i < len(b); i++ {
		c := b[i]

		if !jsIsSpecial(rune(c)) {
			// Fast path: the byte is written later as part of its run.
			continue
		}
		w.Write(b[last:i])

		if c < utf8.RuneSelf {
			switch c {
			case '\\':
				w.Write(jsBackslash)
			case '\'':
				w.Write(jsApos)
			case '"':
				w.Write(jsQuot)
			case '<':
				w.Write(jsLt)
			case '>':
				w.Write(jsGt)
			case '&':
				w.Write(jsAmp)
			case '=':
				w.Write(jsEq)
			default:
				// The remaining ASCII specials are control characters.
				w.Write(jsLowUni)
				hi, lo := c>>4, c&0x0f
				w.Write(hex[hi : hi+1])
				w.Write(hex[lo : lo+1])
			}
		} else {
			// Start of a multi-byte sequence: decode one whole rune.
			r, size := utf8.DecodeRune(b[i:])
			switch {
			case unicode.IsPrint(r):
				w.Write(b[i : i+size])
			case r <= 0xFFFF:
				fmt.Fprintf(w, "\\u%04X", r)
			default:
				// Encode as a UTF-16 surrogate pair; "\u%04X" alone would
				// emit five or six digits, which JavaScript misparses.
				r -= 0x10000
				fmt.Fprintf(w, "\\u%04X\\u%04X", 0xD800+(r>>10), 0xDC00+(r&0x3FF))
			}
			i += size - 1
		}
		last = i + 1
	}
	w.Write(b[last:])
}

// JSEscapeString returns the escaped JavaScript equivalent of the plain text
// data s, as produced by JSEscape. s is returned unchanged when no rune in it
// satisfies jsIsSpecial.
func JSEscapeString(s string) string {
	if strings.IndexFunc(s, jsIsSpecial) < 0 {
		return s
	}
	var b strings.Builder
	JSEscape(&b, []byte(s))
	return b.String()
}

// jsIsSpecial reports whether r needs JSEscape's attention:
//   - the characters \ ' " < > & =;
//   - ASCII control characters below space;
//   - any rune at or above utf8.RuneSelf (JSEscape then decides per rune
//     whether it is printable).
func jsIsSpecial(r rune) bool {
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	}
	return r < ' ' || utf8.RuneSelf <= r
}
747
748
749
750 func JSEscaper(args ...any) string {
751 return JSEscapeString(evalArgs(args))
752 }
753
754
755
756 func URLQueryEscaper(args ...any) string {
757 return url.QueryEscape(evalArgs(args))
758 }
759
760
761
762
763
764
765
766
767 func evalArgs(args []any) string {
768 ok := false
769 var s string
770
771 if len(args) == 1 {
772 s, ok = args[0].(string)
773 }
774 if !ok {
775 for i, arg := range args {
776 a, ok := printableValue(reflect.ValueOf(arg))
777 if ok {
778 args[i] = a
779 }
780 }
781 s = fmt.Sprint(args...)
782 }
783 return s
784 }
785
View as plain text