Source file src/text/template/funcs.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/url"
    12  	"reflect"
    13  	"strings"
    14  	"sync"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type error. In that case, if the second (error)
// return value evaluates to non-nil during execution, execution terminates and
// Execute returns that error.
//
// Each key must be a valid identifier (a letter or underscore followed by
// letters, digits, or underscores) and each value must be a function.
//
// Errors returned by Execute wrap the underlying error; call [errors.As] to
// extract it.
//
// When template execution invokes a function with an argument list, that list
// must be assignable to the function's parameter types. Functions meant to
// apply to arguments of arbitrary type can use parameters of type any
// (interface{}) or of type [reflect.Value]. Similarly, functions meant to
// return a result of arbitrary type can return any or [reflect.Value].
type FuncMap map[string]any
    34  
    35  // builtins returns the FuncMap.
    36  // It is not a global variable so the linker can dead code eliminate
    37  // more when this isn't called. See golang.org/issue/36021.
    38  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    39  func builtins() FuncMap {
    40  	return FuncMap{
    41  		"and":      and,
    42  		"call":     emptyCall,
    43  		"html":     HTMLEscaper,
    44  		"index":    index,
    45  		"slice":    slice,
    46  		"js":       JSEscaper,
    47  		"len":      length,
    48  		"not":      not,
    49  		"or":       or,
    50  		"print":    fmt.Sprint,
    51  		"printf":   fmt.Sprintf,
    52  		"println":  fmt.Sprintln,
    53  		"urlquery": URLQueryEscaper,
    54  
    55  		// Comparisons
    56  		"eq": eq, // ==
    57  		"ge": ge, // >=
    58  		"gt": gt, // >
    59  		"le": le, // <=
    60  		"lt": lt, // <
    61  		"ne": ne, // !=
    62  	}
    63  }
    64  
    65  var builtinFuncsOnce struct {
    66  	sync.Once
    67  	v map[string]reflect.Value
    68  }
    69  
    70  // builtinFuncsOnce lazily computes & caches the builtinFuncs map.
    71  // TODO: revert this back to a global map once golang.org/issue/2559 is fixed.
    72  func builtinFuncs() map[string]reflect.Value {
    73  	builtinFuncsOnce.Do(func() {
    74  		builtinFuncsOnce.v = createValueFuncs(builtins())
    75  	})
    76  	return builtinFuncsOnce.v
    77  }
    78  
    79  // createValueFuncs turns a FuncMap into a map[string]reflect.Value
    80  func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
    81  	m := make(map[string]reflect.Value)
    82  	addValueFuncs(m, funcMap)
    83  	return m
    84  }
    85  
    86  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    87  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    88  	for name, fn := range in {
    89  		if !goodName(name) {
    90  			panic(fmt.Errorf("function name %q is not a valid identifier", name))
    91  		}
    92  		v := reflect.ValueOf(fn)
    93  		if v.Kind() != reflect.Func {
    94  			panic("value for " + name + " not a function")
    95  		}
    96  		if err := goodFunc(name, v.Type()); err != nil {
    97  			panic(err)
    98  		}
    99  		out[name] = v
   100  	}
   101  }
   102  
   103  // addFuncs adds to values the functions in funcs. It does no checking of the input -
   104  // call addValueFuncs first.
   105  func addFuncs(out, in FuncMap) {
   106  	for name, fn := range in {
   107  		out[name] = fn
   108  	}
   109  }
   110  
   111  // goodFunc reports whether the function or method has the right result signature.
   112  func goodFunc(name string, typ reflect.Type) error {
   113  	// We allow functions with 1 result or 2 results where the second is an error.
   114  	switch numOut := typ.NumOut(); {
   115  	case numOut == 1:
   116  		return nil
   117  	case numOut == 2 && typ.Out(1) == errorType:
   118  		return nil
   119  	case numOut == 2:
   120  		return fmt.Errorf("invalid function signature for %s: second return value should be error; is %s", name, typ.Out(1))
   121  	default:
   122  		return fmt.Errorf("function %s has %d return values; should be 1 or 2", name, typ.NumOut())
   123  	}
   124  }
   125  
   126  // goodName reports whether the function name is a valid identifier.
   127  func goodName(name string) bool {
   128  	if name == "" {
   129  		return false
   130  	}
   131  	for i, r := range name {
   132  		switch {
   133  		case r == '_':
   134  		case i == 0 && !unicode.IsLetter(r):
   135  			return false
   136  		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
   137  			return false
   138  		}
   139  	}
   140  	return true
   141  }
   142  
   143  // findFunction looks for a function in the template, and global map.
   144  func findFunction(name string, tmpl *Template) (v reflect.Value, isBuiltin, ok bool) {
   145  	if tmpl != nil && tmpl.common != nil {
   146  		tmpl.muFuncs.RLock()
   147  		defer tmpl.muFuncs.RUnlock()
   148  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
   149  			return fn, false, true
   150  		}
   151  	}
   152  	if fn := builtinFuncs()[name]; fn.IsValid() {
   153  		return fn, true, true
   154  	}
   155  	return reflect.Value{}, false, false
   156  }
   157  
   158  // prepareArg checks if value can be used as an argument of type argType, and
   159  // converts an invalid value to appropriate zero if possible.
   160  func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
   161  	if !value.IsValid() {
   162  		if !canBeNil(argType) {
   163  			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
   164  		}
   165  		value = reflect.Zero(argType)
   166  	}
   167  	if value.Type().AssignableTo(argType) {
   168  		return value, nil
   169  	}
   170  	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
   171  		value = value.Convert(argType)
   172  		return value, nil
   173  	}
   174  	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
   175  }
   176  
   177  func intLike(typ reflect.Kind) bool {
   178  	switch typ {
   179  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   180  		return true
   181  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   182  		return true
   183  	}
   184  	return false
   185  }
   186  
   187  // indexArg checks if a reflect.Value can be used as an index, and converts it to int if possible.
   188  func indexArg(index reflect.Value, cap int) (int, error) {
   189  	var x int64
   190  	switch index.Kind() {
   191  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   192  		x = index.Int()
   193  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   194  		x = int64(index.Uint())
   195  	case reflect.Invalid:
   196  		return 0, fmt.Errorf("cannot index slice/array with nil")
   197  	default:
   198  		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   199  	}
   200  	if x < 0 || int(x) < 0 || int(x) > cap {
   201  		return 0, fmt.Errorf("index out of range: %d", x)
   202  	}
   203  	return int(x), nil
   204  }
   205  
   206  // Indexing.
   207  
   208  // index returns the result of indexing its first argument by the following
   209  // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   210  // indexed item must be a map, slice, or array.
   211  func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   212  	item = indirectInterface(item)
   213  	if !item.IsValid() {
   214  		return reflect.Value{}, fmt.Errorf("index of untyped nil")
   215  	}
   216  	for _, index := range indexes {
   217  		index = indirectInterface(index)
   218  		var isNil bool
   219  		if item, isNil = indirect(item); isNil {
   220  			return reflect.Value{}, fmt.Errorf("index of nil pointer")
   221  		}
   222  		switch item.Kind() {
   223  		case reflect.Array, reflect.Slice, reflect.String:
   224  			x, err := indexArg(index, item.Len())
   225  			if err != nil {
   226  				return reflect.Value{}, err
   227  			}
   228  			item = item.Index(x)
   229  		case reflect.Map:
   230  			index, err := prepareArg(index, item.Type().Key())
   231  			if err != nil {
   232  				return reflect.Value{}, err
   233  			}
   234  			if x := item.MapIndex(index); x.IsValid() {
   235  				item = x
   236  			} else {
   237  				item = reflect.Zero(item.Type().Elem())
   238  			}
   239  		case reflect.Invalid:
   240  			// the loop holds invariant: item.IsValid()
   241  			panic("unreachable")
   242  		default:
   243  			return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
   244  		}
   245  	}
   246  	return item, nil
   247  }
   248  
   249  // Slicing.
   250  
   251  // slice returns the result of slicing its first argument by the remaining
   252  // arguments. Thus "slice x 1 2" is, in Go syntax, x[1:2], while "slice x"
   253  // is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
   254  // argument must be a string, slice, or array.
   255  func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
   256  	item = indirectInterface(item)
   257  	if !item.IsValid() {
   258  		return reflect.Value{}, fmt.Errorf("slice of untyped nil")
   259  	}
   260  	if len(indexes) > 3 {
   261  		return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
   262  	}
   263  	var cap int
   264  	switch item.Kind() {
   265  	case reflect.String:
   266  		if len(indexes) == 3 {
   267  			return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
   268  		}
   269  		cap = item.Len()
   270  	case reflect.Array, reflect.Slice:
   271  		cap = item.Cap()
   272  	default:
   273  		return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
   274  	}
   275  	// set default values for cases item[:], item[i:].
   276  	idx := [3]int{0, item.Len()}
   277  	for i, index := range indexes {
   278  		x, err := indexArg(index, cap)
   279  		if err != nil {
   280  			return reflect.Value{}, err
   281  		}
   282  		idx[i] = x
   283  	}
   284  	// given item[i:j], make sure i <= j.
   285  	if idx[0] > idx[1] {
   286  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
   287  	}
   288  	if len(indexes) < 3 {
   289  		return item.Slice(idx[0], idx[1]), nil
   290  	}
   291  	// given item[i:j:k], make sure i <= j <= k.
   292  	if idx[1] > idx[2] {
   293  		return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
   294  	}
   295  	return item.Slice3(idx[0], idx[1], idx[2]), nil
   296  }
   297  
   298  // Length
   299  
   300  // length returns the length of the item, with an error if it has no defined length.
   301  func length(item reflect.Value) (int, error) {
   302  	item, isNil := indirect(item)
   303  	if isNil {
   304  		return 0, fmt.Errorf("len of nil pointer")
   305  	}
   306  	switch item.Kind() {
   307  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   308  		return item.Len(), nil
   309  	}
   310  	return 0, fmt.Errorf("len of type %s", item.Type())
   311  }
   312  
   313  // Function invocation
   314  
   315  func emptyCall(fn reflect.Value, args ...reflect.Value) reflect.Value {
   316  	panic("unreachable") // implemented as a special case in evalCall
   317  }
   318  
   319  // call returns the result of evaluating the first argument as a function.
   320  // The function must return 1 result, or 2 results, the second of which is an error.
   321  func call(name string, fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
   322  	fn = indirectInterface(fn)
   323  	if !fn.IsValid() {
   324  		return reflect.Value{}, fmt.Errorf("call of nil")
   325  	}
   326  	typ := fn.Type()
   327  	if typ.Kind() != reflect.Func {
   328  		return reflect.Value{}, fmt.Errorf("non-function %s of type %s", name, typ)
   329  	}
   330  
   331  	if err := goodFunc(name, typ); err != nil {
   332  		return reflect.Value{}, err
   333  	}
   334  	numIn := typ.NumIn()
   335  	var dddType reflect.Type
   336  	if typ.IsVariadic() {
   337  		if len(args) < numIn-1 {
   338  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want at least %d", name, len(args), numIn-1)
   339  		}
   340  		dddType = typ.In(numIn - 1).Elem()
   341  	} else {
   342  		if len(args) != numIn {
   343  			return reflect.Value{}, fmt.Errorf("wrong number of args for %s: got %d want %d", name, len(args), numIn)
   344  		}
   345  	}
   346  	argv := make([]reflect.Value, len(args))
   347  	for i, arg := range args {
   348  		arg = indirectInterface(arg)
   349  		// Compute the expected type. Clumsy because of variadics.
   350  		argType := dddType
   351  		if !typ.IsVariadic() || i < numIn-1 {
   352  			argType = typ.In(i)
   353  		}
   354  
   355  		var err error
   356  		if argv[i], err = prepareArg(arg, argType); err != nil {
   357  			return reflect.Value{}, fmt.Errorf("arg %d: %w", i, err)
   358  		}
   359  	}
   360  	return safeCall(fn, argv)
   361  }
   362  
   363  // safeCall runs fun.Call(args), and returns the resulting value and error, if
   364  // any. If the call panics, the panic value is returned as an error.
   365  func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
   366  	defer func() {
   367  		if r := recover(); r != nil {
   368  			if e, ok := r.(error); ok {
   369  				err = e
   370  			} else {
   371  				err = fmt.Errorf("%v", r)
   372  			}
   373  		}
   374  	}()
   375  	ret := fun.Call(args)
   376  	if len(ret) == 2 && !ret[1].IsNil() {
   377  		return ret[0], ret[1].Interface().(error)
   378  	}
   379  	return ret[0], nil
   380  }
   381  
   382  // Boolean logic.
   383  
   384  func truth(arg reflect.Value) bool {
   385  	t, _ := isTrue(indirectInterface(arg))
   386  	return t
   387  }
   388  
   389  // and computes the Boolean AND of its arguments, returning
   390  // the first false argument it encounters, or the last argument.
   391  func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   392  	panic("unreachable") // implemented as a special case in evalCall
   393  }
   394  
   395  // or computes the Boolean OR of its arguments, returning
   396  // the first true argument it encounters, or the last argument.
   397  func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   398  	panic("unreachable") // implemented as a special case in evalCall
   399  }
   400  
   401  // not returns the Boolean negation of its argument.
   402  func not(arg reflect.Value) bool {
   403  	return !truth(arg)
   404  }
   405  
   406  // Comparison.
   407  
// Signed and unsigned integers may be compared with each other: eq and lt
// handle the intKind/uintKind mix explicitly instead of rejecting it.
   409  
   410  var (
   411  	errBadComparisonType = errors.New("invalid type for comparison")
   412  	errBadComparison     = errors.New("incompatible types for comparison")
   413  	errNoComparison      = errors.New("missing argument for comparison")
   414  )
   415  
   416  type kind int
   417  
   418  const (
   419  	invalidKind kind = iota
   420  	boolKind
   421  	complexKind
   422  	intKind
   423  	floatKind
   424  	stringKind
   425  	uintKind
   426  )
   427  
   428  func basicKind(v reflect.Value) (kind, error) {
   429  	switch v.Kind() {
   430  	case reflect.Bool:
   431  		return boolKind, nil
   432  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   433  		return intKind, nil
   434  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   435  		return uintKind, nil
   436  	case reflect.Float32, reflect.Float64:
   437  		return floatKind, nil
   438  	case reflect.Complex64, reflect.Complex128:
   439  		return complexKind, nil
   440  	case reflect.String:
   441  		return stringKind, nil
   442  	}
   443  	return invalidKind, errBadComparisonType
   444  }
   445  
   446  // isNil returns true if v is the zero reflect.Value, or nil of its type.
   447  func isNil(v reflect.Value) bool {
   448  	if !v.IsValid() {
   449  		return true
   450  	}
   451  	switch v.Kind() {
   452  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
   453  		return v.IsNil()
   454  	}
   455  	return false
   456  }
   457  
   458  // canCompare reports whether v1 and v2 are both the same kind, or one is nil.
   459  // Called only when dealing with nillable types, or there's about to be an error.
   460  func canCompare(v1, v2 reflect.Value) bool {
   461  	k1 := v1.Kind()
   462  	k2 := v2.Kind()
   463  	if k1 == k2 {
   464  		return true
   465  	}
   466  	// We know the type can be compared to nil.
   467  	return k1 == reflect.Invalid || k2 == reflect.Invalid
   468  }
   469  
   470  // eq evaluates the comparison a == b || a == c || ...
   471  func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
   472  	arg1 = indirectInterface(arg1)
   473  	if len(arg2) == 0 {
   474  		return false, errNoComparison
   475  	}
   476  	k1, _ := basicKind(arg1)
   477  	for _, arg := range arg2 {
   478  		arg = indirectInterface(arg)
   479  		k2, _ := basicKind(arg)
   480  		truth := false
   481  		if k1 != k2 {
   482  			// Special case: Can compare integer values regardless of type's sign.
   483  			switch {
   484  			case k1 == intKind && k2 == uintKind:
   485  				truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
   486  			case k1 == uintKind && k2 == intKind:
   487  				truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
   488  			default:
   489  				if arg1.IsValid() && arg.IsValid() {
   490  					return false, errBadComparison
   491  				}
   492  			}
   493  		} else {
   494  			switch k1 {
   495  			case boolKind:
   496  				truth = arg1.Bool() == arg.Bool()
   497  			case complexKind:
   498  				truth = arg1.Complex() == arg.Complex()
   499  			case floatKind:
   500  				truth = arg1.Float() == arg.Float()
   501  			case intKind:
   502  				truth = arg1.Int() == arg.Int()
   503  			case stringKind:
   504  				truth = arg1.String() == arg.String()
   505  			case uintKind:
   506  				truth = arg1.Uint() == arg.Uint()
   507  			default:
   508  				if !canCompare(arg1, arg) {
   509  					return false, fmt.Errorf("non-comparable types %s: %v, %s: %v", arg1, arg1.Type(), arg.Type(), arg)
   510  				}
   511  				if isNil(arg1) || isNil(arg) {
   512  					truth = isNil(arg) == isNil(arg1)
   513  				} else {
   514  					if !arg.Type().Comparable() {
   515  						return false, fmt.Errorf("non-comparable type %s: %v", arg, arg.Type())
   516  					}
   517  					truth = arg1.Interface() == arg.Interface()
   518  				}
   519  			}
   520  		}
   521  		if truth {
   522  			return true, nil
   523  		}
   524  	}
   525  	return false, nil
   526  }
   527  
   528  // ne evaluates the comparison a != b.
   529  func ne(arg1, arg2 reflect.Value) (bool, error) {
   530  	// != is the inverse of ==.
   531  	equal, err := eq(arg1, arg2)
   532  	return !equal, err
   533  }
   534  
   535  // lt evaluates the comparison a < b.
   536  func lt(arg1, arg2 reflect.Value) (bool, error) {
   537  	arg1 = indirectInterface(arg1)
   538  	k1, err := basicKind(arg1)
   539  	if err != nil {
   540  		return false, err
   541  	}
   542  	arg2 = indirectInterface(arg2)
   543  	k2, err := basicKind(arg2)
   544  	if err != nil {
   545  		return false, err
   546  	}
   547  	truth := false
   548  	if k1 != k2 {
   549  		// Special case: Can compare integer values regardless of type's sign.
   550  		switch {
   551  		case k1 == intKind && k2 == uintKind:
   552  			truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
   553  		case k1 == uintKind && k2 == intKind:
   554  			truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
   555  		default:
   556  			return false, errBadComparison
   557  		}
   558  	} else {
   559  		switch k1 {
   560  		case boolKind, complexKind:
   561  			return false, errBadComparisonType
   562  		case floatKind:
   563  			truth = arg1.Float() < arg2.Float()
   564  		case intKind:
   565  			truth = arg1.Int() < arg2.Int()
   566  		case stringKind:
   567  			truth = arg1.String() < arg2.String()
   568  		case uintKind:
   569  			truth = arg1.Uint() < arg2.Uint()
   570  		default:
   571  			panic("invalid kind")
   572  		}
   573  	}
   574  	return truth, nil
   575  }
   576  
   577  // le evaluates the comparison <= b.
   578  func le(arg1, arg2 reflect.Value) (bool, error) {
   579  	// <= is < or ==.
   580  	lessThan, err := lt(arg1, arg2)
   581  	if lessThan || err != nil {
   582  		return lessThan, err
   583  	}
   584  	return eq(arg1, arg2)
   585  }
   586  
   587  // gt evaluates the comparison a > b.
   588  func gt(arg1, arg2 reflect.Value) (bool, error) {
   589  	// > is the inverse of <=.
   590  	lessOrEqual, err := le(arg1, arg2)
   591  	if err != nil {
   592  		return false, err
   593  	}
   594  	return !lessOrEqual, nil
   595  }
   596  
   597  // ge evaluates the comparison a >= b.
   598  func ge(arg1, arg2 reflect.Value) (bool, error) {
   599  	// >= is the inverse of <.
   600  	lessThan, err := lt(arg1, arg2)
   601  	if err != nil {
   602  		return false, err
   603  	}
   604  	return !lessThan, nil
   605  }
   606  
   607  // HTML escaping.
   608  
   609  var (
   610  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   611  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   612  	htmlAmp  = []byte("&amp;")
   613  	htmlLt   = []byte("&lt;")
   614  	htmlGt   = []byte("&gt;")
   615  	htmlNull = []byte("\uFFFD")
   616  )
   617  
   618  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   619  func HTMLEscape(w io.Writer, b []byte) {
   620  	last := 0
   621  	for i, c := range b {
   622  		var html []byte
   623  		switch c {
   624  		case '\000':
   625  			html = htmlNull
   626  		case '"':
   627  			html = htmlQuot
   628  		case '\'':
   629  			html = htmlApos
   630  		case '&':
   631  			html = htmlAmp
   632  		case '<':
   633  			html = htmlLt
   634  		case '>':
   635  			html = htmlGt
   636  		default:
   637  			continue
   638  		}
   639  		w.Write(b[last:i])
   640  		w.Write(html)
   641  		last = i + 1
   642  	}
   643  	w.Write(b[last:])
   644  }
   645  
   646  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   647  func HTMLEscapeString(s string) string {
   648  	// Avoid allocation if we can.
   649  	if !strings.ContainsAny(s, "'\"&<>\000") {
   650  		return s
   651  	}
   652  	var b strings.Builder
   653  	HTMLEscape(&b, []byte(s))
   654  	return b.String()
   655  }
   656  
   657  // HTMLEscaper returns the escaped HTML equivalent of the textual
   658  // representation of its arguments.
   659  func HTMLEscaper(args ...any) string {
   660  	return HTMLEscapeString(evalArgs(args))
   661  }
   662  
   663  // JavaScript escaping.
   664  
   665  var (
   666  	jsLowUni = []byte(`\u00`)
   667  	hex      = []byte("0123456789ABCDEF")
   668  
   669  	jsBackslash = []byte(`\\`)
   670  	jsApos      = []byte(`\'`)
   671  	jsQuot      = []byte(`\"`)
   672  	jsLt        = []byte(`\u003C`)
   673  	jsGt        = []byte(`\u003E`)
   674  	jsAmp       = []byte(`\u0026`)
   675  	jsEq        = []byte(`\u003D`)
   676  )
   677  
   678  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   679  func JSEscape(w io.Writer, b []byte) {
   680  	last := 0
   681  	for i := 0; i < len(b); i++ {
   682  		c := b[i]
   683  
   684  		if !jsIsSpecial(rune(c)) {
   685  			// fast path: nothing to do
   686  			continue
   687  		}
   688  		w.Write(b[last:i])
   689  
   690  		if c < utf8.RuneSelf {
   691  			// Quotes, slashes and angle brackets get quoted.
   692  			// Control characters get written as \u00XX.
   693  			switch c {
   694  			case '\\':
   695  				w.Write(jsBackslash)
   696  			case '\'':
   697  				w.Write(jsApos)
   698  			case '"':
   699  				w.Write(jsQuot)
   700  			case '<':
   701  				w.Write(jsLt)
   702  			case '>':
   703  				w.Write(jsGt)
   704  			case '&':
   705  				w.Write(jsAmp)
   706  			case '=':
   707  				w.Write(jsEq)
   708  			default:
   709  				w.Write(jsLowUni)
   710  				t, b := c>>4, c&0x0f
   711  				w.Write(hex[t : t+1])
   712  				w.Write(hex[b : b+1])
   713  			}
   714  		} else {
   715  			// Unicode rune.
   716  			r, size := utf8.DecodeRune(b[i:])
   717  			if unicode.IsPrint(r) {
   718  				w.Write(b[i : i+size])
   719  			} else {
   720  				fmt.Fprintf(w, "\\u%04X", r)
   721  			}
   722  			i += size - 1
   723  		}
   724  		last = i + 1
   725  	}
   726  	w.Write(b[last:])
   727  }
   728  
   729  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   730  func JSEscapeString(s string) string {
   731  	// Avoid allocation if we can.
   732  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   733  		return s
   734  	}
   735  	var b strings.Builder
   736  	JSEscape(&b, []byte(s))
   737  	return b.String()
   738  }
   739  
   740  func jsIsSpecial(r rune) bool {
   741  	switch r {
   742  	case '\\', '\'', '"', '<', '>', '&', '=':
   743  		return true
   744  	}
   745  	return r < ' ' || utf8.RuneSelf <= r
   746  }
   747  
   748  // JSEscaper returns the escaped JavaScript equivalent of the textual
   749  // representation of its arguments.
   750  func JSEscaper(args ...any) string {
   751  	return JSEscapeString(evalArgs(args))
   752  }
   753  
   754  // URLQueryEscaper returns the escaped value of the textual representation of
   755  // its arguments in a form suitable for embedding in a URL query.
   756  func URLQueryEscaper(args ...any) string {
   757  	return url.QueryEscape(evalArgs(args))
   758  }
   759  
   760  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   761  //
   762  //	fmt.Sprint(args...)
   763  //
   764  // except that each argument is indirected (if a pointer), as required,
   765  // using the same rules as the default string evaluation during template
   766  // execution.
   767  func evalArgs(args []any) string {
   768  	ok := false
   769  	var s string
   770  	// Fast path for simple common case.
   771  	if len(args) == 1 {
   772  		s, ok = args[0].(string)
   773  	}
   774  	if !ok {
   775  		for i, arg := range args {
   776  			a, ok := printableValue(reflect.ValueOf(arg))
   777  			if ok {
   778  				args[i] = a
   779  			} // else let fmt do its thing
   780  		}
   781  		s = fmt.Sprint(args...)
   782  	}
   783  	return s
   784  }
   785  

View as plain text