Source file src/cmd/compile/internal/walk/switch.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package walk
     6  
     7  import (
     8  	"cmp"
     9  	"fmt"
    10  	"go/constant"
    11  	"go/token"
    12  	"math/bits"
    13  	"slices"
    14  	"sort"
    15  	"strings"
    16  
    17  	"cmd/compile/internal/base"
    18  	"cmd/compile/internal/ir"
    19  	"cmd/compile/internal/objw"
    20  	"cmd/compile/internal/reflectdata"
    21  	"cmd/compile/internal/rttype"
    22  	"cmd/compile/internal/ssagen"
    23  	"cmd/compile/internal/typecheck"
    24  	"cmd/compile/internal/types"
    25  	"cmd/internal/obj"
    26  	"cmd/internal/src"
    27  )
    28  
    29  // walkSwitch walks a switch statement.
    30  func walkSwitch(sw *ir.SwitchStmt) {
    31  	// Guard against double walk, see #25776.
    32  	if sw.Walked() {
    33  		return // Was fatal, but eliminating every possible source of double-walking is hard
    34  	}
    35  	sw.SetWalked(true)
    36  
    37  	if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
    38  		walkSwitchType(sw)
    39  	} else {
    40  		walkSwitchExpr(sw)
    41  	}
    42  }
    43  
    44  // walkSwitchExpr generates an AST implementing sw.  sw is an
    45  // expression switch.
    46  func walkSwitchExpr(sw *ir.SwitchStmt) {
    47  	lno := ir.SetPos(sw)
    48  
    49  	cond := sw.Tag
    50  	sw.Tag = nil
    51  
    52  	// convert switch {...} to switch true {...}
    53  	if cond == nil {
    54  		cond = ir.NewBool(base.Pos, true)
    55  		cond = typecheck.Expr(cond)
    56  		cond = typecheck.DefaultLit(cond, nil)
    57  	}
    58  
    59  	// Given "switch string(byteslice)",
    60  	// with all cases being side-effect free,
    61  	// use a zero-cost alias of the byte slice.
    62  	// Do this before calling walkExpr on cond,
    63  	// because walkExpr will lower the string
    64  	// conversion into a runtime call.
    65  	// See issue 24937 for more discussion.
    66  	if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
    67  		cond := cond.(*ir.ConvExpr)
    68  		cond.SetOp(ir.OBYTES2STRTMP)
    69  	}
    70  
    71  	cond = walkExpr(cond, sw.PtrInit())
    72  	if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
    73  		cond = copyExpr(cond, cond.Type(), &sw.Compiled)
    74  	}
    75  
    76  	base.Pos = lno
    77  
    78  	s := exprSwitch{
    79  		pos:      lno,
    80  		exprname: cond,
    81  	}
    82  
    83  	var defaultGoto ir.Node
    84  	var body ir.Nodes
    85  	for _, ncase := range sw.Cases {
    86  		label := typecheck.AutoLabel(".s")
    87  		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
    88  
    89  		// Process case dispatch.
    90  		if len(ncase.List) == 0 {
    91  			if defaultGoto != nil {
    92  				base.Fatalf("duplicate default case not detected during typechecking")
    93  			}
    94  			defaultGoto = jmp
    95  		}
    96  
    97  		for i, n1 := range ncase.List {
    98  			var rtype ir.Node
    99  			if i < len(ncase.RTypes) {
   100  				rtype = ncase.RTypes[i]
   101  			}
   102  			s.Add(ncase.Pos(), n1, rtype, jmp)
   103  		}
   104  
   105  		// Process body.
   106  		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
   107  		body.Append(ncase.Body...)
   108  		if fall, pos := endsInFallthrough(ncase.Body); !fall {
   109  			br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
   110  			br.SetPos(pos)
   111  			body.Append(br)
   112  		}
   113  	}
   114  	sw.Cases = nil
   115  
   116  	if defaultGoto == nil {
   117  		br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
   118  		br.SetPos(br.Pos().WithNotStmt())
   119  		defaultGoto = br
   120  	}
   121  
   122  	s.Emit(&sw.Compiled)
   123  	sw.Compiled.Append(defaultGoto)
   124  	sw.Compiled.Append(body.Take()...)
   125  	walkStmtList(sw.Compiled)
   126  }
   127  
   128  // An exprSwitch walks an expression switch.
   129  type exprSwitch struct {
   130  	pos      src.XPos
   131  	exprname ir.Node // value being switched on
   132  
   133  	done    ir.Nodes
   134  	clauses []exprClause
   135  }
   136  
   137  type exprClause struct {
   138  	pos    src.XPos
   139  	lo, hi ir.Node
   140  	rtype  ir.Node // *runtime._type for OEQ node
   141  	jmp    ir.Node
   142  }
   143  
   144  func (s *exprSwitch) Add(pos src.XPos, expr, rtype, jmp ir.Node) {
   145  	c := exprClause{pos: pos, lo: expr, hi: expr, rtype: rtype, jmp: jmp}
   146  	if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
   147  		s.clauses = append(s.clauses, c)
   148  		return
   149  	}
   150  
   151  	s.flush()
   152  	s.clauses = append(s.clauses, c)
   153  	s.flush()
   154  }
   155  
   156  func (s *exprSwitch) Emit(out *ir.Nodes) {
   157  	s.flush()
   158  	out.Append(s.done.Take()...)
   159  }
   160  
   161  func (s *exprSwitch) flush() {
   162  	cc := s.clauses
   163  	s.clauses = nil
   164  	if len(cc) == 0 {
   165  		return
   166  	}
   167  
   168  	// Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
   169  	// The code below is structured to implicitly handle this case
   170  	// (e.g., sort.Slice doesn't need to invoke the less function
   171  	// when there's only a single slice element).
   172  
   173  	if s.exprname.Type().IsString() && len(cc) >= 2 {
   174  		// Sort strings by length and then by value. It is
   175  		// much cheaper to compare lengths than values, and
   176  		// all we need here is consistency. We respect this
   177  		// sorting below.
   178  		slices.SortFunc(cc, func(a, b exprClause) int {
   179  			si := ir.StringVal(a.lo)
   180  			sj := ir.StringVal(b.lo)
   181  			if len(si) != len(sj) {
   182  				return cmp.Compare(len(si), len(sj))
   183  			}
   184  			return strings.Compare(si, sj)
   185  		})
   186  
   187  		// runLen returns the string length associated with a
   188  		// particular run of exprClauses.
   189  		runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }
   190  
   191  		// Collapse runs of consecutive strings with the same length.
   192  		var runs [][]exprClause
   193  		start := 0
   194  		for i := 1; i < len(cc); i++ {
   195  			if runLen(cc[start:]) != runLen(cc[i:]) {
   196  				runs = append(runs, cc[start:i])
   197  				start = i
   198  			}
   199  		}
   200  		runs = append(runs, cc[start:])
   201  
   202  		// We have strings of more than one length. Generate an
   203  		// outer switch which switches on the length of the string
   204  		// and an inner switch in each case which resolves all the
   205  		// strings of the same length. The code looks something like this:
   206  
   207  		// goto outerLabel
   208  		// len5:
   209  		//   ... search among length 5 strings ...
   210  		//   goto endLabel
   211  		// len8:
   212  		//   ... search among length 8 strings ...
   213  		//   goto endLabel
   214  		// ... other lengths ...
   215  		// outerLabel:
   216  		// switch len(s) {
   217  		//   case 5: goto len5
   218  		//   case 8: goto len8
   219  		//   ... other lengths ...
   220  		// }
   221  		// endLabel:
   222  
   223  		outerLabel := typecheck.AutoLabel(".s")
   224  		endLabel := typecheck.AutoLabel(".s")
   225  
   226  		// Jump around all the individual switches for each length.
   227  		s.done.Append(ir.NewBranchStmt(s.pos, ir.OGOTO, outerLabel))
   228  
   229  		var outer exprSwitch
   230  		outer.exprname = ir.NewUnaryExpr(s.pos, ir.OLEN, s.exprname)
   231  		outer.exprname.SetType(types.Types[types.TINT])
   232  
   233  		for _, run := range runs {
   234  			// Target label to jump to when we match this length.
   235  			label := typecheck.AutoLabel(".s")
   236  
   237  			// Search within this run of same-length strings.
   238  			pos := run[0].pos
   239  			s.done.Append(ir.NewLabelStmt(pos, label))
   240  			stringSearch(s.exprname, run, &s.done)
   241  			s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel))
   242  
   243  			// Add length case to outer switch.
   244  			cas := ir.NewInt(pos, runLen(run))
   245  			jmp := ir.NewBranchStmt(pos, ir.OGOTO, label)
   246  			outer.Add(pos, cas, nil, jmp)
   247  		}
   248  		s.done.Append(ir.NewLabelStmt(s.pos, outerLabel))
   249  		outer.Emit(&s.done)
   250  		s.done.Append(ir.NewLabelStmt(s.pos, endLabel))
   251  		return
   252  	}
   253  
   254  	sort.Slice(cc, func(i, j int) bool {
   255  		return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
   256  	})
   257  
   258  	// Merge consecutive integer cases.
   259  	if s.exprname.Type().IsInteger() {
   260  		consecutive := func(last, next constant.Value) bool {
   261  			delta := constant.BinaryOp(next, token.SUB, last)
   262  			return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
   263  		}
   264  
   265  		merged := cc[:1]
   266  		for _, c := range cc[1:] {
   267  			last := &merged[len(merged)-1]
   268  			if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
   269  				last.hi = c.lo
   270  			} else {
   271  				merged = append(merged, c)
   272  			}
   273  		}
   274  		cc = merged
   275  	}
   276  
   277  	s.search(cc, &s.done)
   278  }
   279  
   280  func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
   281  	if s.tryJumpTable(cc, out) {
   282  		return
   283  	}
   284  	binarySearch(len(cc), out,
   285  		func(i int) ir.Node {
   286  			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
   287  		},
   288  		func(i int, nif *ir.IfStmt) {
   289  			c := &cc[i]
   290  			nif.Cond = c.test(s.exprname)
   291  			nif.Body = []ir.Node{c.jmp}
   292  		},
   293  	)
   294  }
   295  
   296  // Try to implement the clauses with a jump table. Returns true if successful.
   297  func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool {
   298  	const minCases = 8   // have at least minCases cases in the switch
   299  	const minDensity = 4 // use at least 1 out of every minDensity entries
   300  
   301  	if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
   302  		return false
   303  	}
   304  	if len(cc) < minCases {
   305  		return false // not enough cases for it to be worth it
   306  	}
   307  	if cc[0].lo.Val().Kind() != constant.Int {
   308  		return false // e.g. float
   309  	}
   310  	if s.exprname.Type().Size() > int64(types.PtrSize) {
   311  		return false // 64-bit switches on 32-bit archs
   312  	}
   313  	min := cc[0].lo.Val()
   314  	max := cc[len(cc)-1].hi.Val()
   315  	width := constant.BinaryOp(constant.BinaryOp(max, token.SUB, min), token.ADD, constant.MakeInt64(1))
   316  	limit := constant.MakeInt64(int64(len(cc)) * minDensity)
   317  	if constant.Compare(width, token.GTR, limit) {
   318  		// We disable jump tables if we use less than a minimum fraction of the entries.
   319  		// i.e. for switch x {case 0: case 1000: case 2000:} we don't want to use a jump table.
   320  		return false
   321  	}
   322  	jt := ir.NewJumpTableStmt(base.Pos, s.exprname)
   323  	for _, c := range cc {
   324  		jmp := c.jmp.(*ir.BranchStmt)
   325  		if jmp.Op() != ir.OGOTO || jmp.Label == nil {
   326  			panic("bad switch case body")
   327  		}
   328  		for i := c.lo.Val(); constant.Compare(i, token.LEQ, c.hi.Val()); i = constant.BinaryOp(i, token.ADD, constant.MakeInt64(1)) {
   329  			jt.Cases = append(jt.Cases, i)
   330  			jt.Targets = append(jt.Targets, jmp.Label)
   331  		}
   332  	}
   333  	out.Append(jt)
   334  	return true
   335  }
   336  
   337  func (c *exprClause) test(exprname ir.Node) ir.Node {
   338  	// Integer range.
   339  	if c.hi != c.lo {
   340  		low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
   341  		high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
   342  		return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
   343  	}
   344  
   345  	// Optimize "switch true { ...}" and "switch false { ... }".
   346  	if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
   347  		if ir.BoolVal(exprname) {
   348  			return c.lo
   349  		} else {
   350  			return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
   351  		}
   352  	}
   353  
   354  	n := ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
   355  	n.RType = c.rtype
   356  	return n
   357  }
   358  
   359  func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
   360  	// In theory, we could be more aggressive, allowing any
   361  	// side-effect-free expressions in cases, but it's a bit
   362  	// tricky because some of that information is unavailable due
   363  	// to the introduction of temporaries during order.
   364  	// Restricting to constants is simple and probably powerful
   365  	// enough.
   366  
   367  	for _, ncase := range sw.Cases {
   368  		for _, v := range ncase.List {
   369  			if v.Op() != ir.OLITERAL {
   370  				return false
   371  			}
   372  		}
   373  	}
   374  	return true
   375  }
   376  
   377  // endsInFallthrough reports whether stmts ends with a "fallthrough" statement.
   378  func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
   379  	if len(stmts) == 0 {
   380  		return false, src.NoXPos
   381  	}
   382  	i := len(stmts) - 1
   383  	return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
   384  }
   385  
   386  // walkSwitchType generates an AST that implements sw, where sw is a
   387  // type switch.
   388  func walkSwitchType(sw *ir.SwitchStmt) {
   389  	var s typeSwitch
   390  	s.srcName = sw.Tag.(*ir.TypeSwitchGuard).X
   391  	s.srcName = walkExpr(s.srcName, sw.PtrInit())
   392  	s.srcName = copyExpr(s.srcName, s.srcName.Type(), &sw.Compiled)
   393  	s.okName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
   394  	s.itabName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TUINT8].PtrTo())
   395  
   396  	// Get interface descriptor word.
   397  	// For empty interfaces this will be the type.
   398  	// For non-empty interfaces this will be the itab.
   399  	srcItab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.srcName)
   400  	srcData := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s.srcName)
   401  	srcData.SetType(types.Types[types.TUINT8].PtrTo())
   402  	srcData.SetTypecheck(1)
   403  
   404  	// For empty interfaces, do:
   405  	//     if e._type == nil {
   406  	//         do nil case if it exists, otherwise default
   407  	//     }
   408  	//     h := e._type.hash
   409  	// Use a similar strategy for non-empty interfaces.
   410  	ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
   411  	ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, srcItab, typecheck.NodNil())
   412  	base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
   413  	ifNil.Cond = typecheck.Expr(ifNil.Cond)
   414  	ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
   415  	// ifNil.Nbody assigned later.
   416  	sw.Compiled.Append(ifNil)
   417  
   418  	// Load hash from type or itab.
   419  	dotHash := typeHashFieldOf(base.Pos, srcItab)
   420  	s.hashName = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)
   421  
   422  	// Make a label for each case body.
   423  	labels := make([]*types.Sym, len(sw.Cases))
   424  	for i := range sw.Cases {
   425  		labels[i] = typecheck.AutoLabel(".s")
   426  	}
   427  
   428  	// "jump" to execute if no case matches.
   429  	br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
   430  
   431  	// Assemble a list of all the types we're looking for.
   432  	// This pass flattens the case lists, as well as handles
   433  	// some unusual cases, like default and nil cases.
   434  	type oneCase struct {
   435  		pos src.XPos
   436  		jmp ir.Node // jump to body of selected case
   437  
   438  		// The case we're matching. Normally the type we're looking for
   439  		// is typ.Type(), but when typ is ODYNAMICTYPE the actual type
   440  		// we're looking for is not a compile-time constant (typ.Type()
   441  		// will be its shape).
   442  		typ ir.Node
   443  
   444  		// For a single runtime known type with a case var, create a
   445  		// temporary variable to hold the value returned by the dynamic
   446  		// type assert expr, so that we do not need one more dynamic
   447  		// type assert expr later.
   448  		val ir.Node
   449  		idx int // index of the single runtime known type in sw.Cases
   450  	}
   451  	var cases []oneCase
   452  	var defaultGoto, nilGoto ir.Node
   453  	for i, ncase := range sw.Cases {
   454  		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, labels[i])
   455  		if len(ncase.List) == 0 { // default:
   456  			if defaultGoto != nil {
   457  				base.Fatalf("duplicate default case not detected during typechecking")
   458  			}
   459  			defaultGoto = jmp
   460  		}
   461  		for _, n1 := range ncase.List {
   462  			if ir.IsNil(n1) { // case nil:
   463  				if nilGoto != nil {
   464  					base.Fatalf("duplicate nil case not detected during typechecking")
   465  				}
   466  				nilGoto = jmp
   467  				continue
   468  			}
   469  			idx := -1
   470  			var val ir.Node
   471  			// for a single runtime known type with a case var, create the tmpVar
   472  			if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE && ncase.Var != nil {
   473  				val = typecheck.TempAt(ncase.Pos(), ir.CurFunc, ncase.Var.Type())
   474  				idx = i
   475  			}
   476  			cases = append(cases, oneCase{
   477  				pos: ncase.Pos(),
   478  				typ: n1,
   479  				jmp: jmp,
   480  				val: val,
   481  				idx: idx,
   482  			})
   483  		}
   484  	}
   485  	if defaultGoto == nil {
   486  		defaultGoto = br
   487  	}
   488  	if nilGoto == nil {
   489  		nilGoto = defaultGoto
   490  	}
   491  	ifNil.Body = []ir.Node{nilGoto}
   492  
   493  	// Now go through the list of cases, processing groups as we find them.
   494  	var concreteCases []oneCase
   495  	var interfaceCases []oneCase
   496  	flush := func() {
   497  		// Process all the concrete types first. Because we handle shadowing
   498  		// below, it is correct to do all the concrete types before all of
   499  		// the interface types.
   500  		// The concrete cases can all be handled without a runtime call.
   501  		if len(concreteCases) > 0 {
   502  			var clauses []typeClause
   503  			for _, c := range concreteCases {
   504  				as := ir.NewAssignListStmt(c.pos, ir.OAS2,
   505  					[]ir.Node{ir.BlankNode, s.okName},                               // _, ok =
   506  					[]ir.Node{ir.NewTypeAssertExpr(c.pos, s.srcName, c.typ.Type())}) // iface.(type)
   507  				nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
   508  				clauses = append(clauses, typeClause{
   509  					hash: types.TypeHash(c.typ.Type()),
   510  					body: []ir.Node{typecheck.Stmt(as), typecheck.Stmt(nif)},
   511  				})
   512  			}
   513  			s.flush(clauses, &sw.Compiled)
   514  			concreteCases = concreteCases[:0]
   515  		}
   516  
   517  		// The "any" case, if it exists, must be the last interface case, because
   518  		// it would shadow all subsequent cases. Strip it off here so the runtime
   519  		// call only needs to handle non-empty interfaces.
   520  		var anyGoto ir.Node
   521  		if len(interfaceCases) > 0 && interfaceCases[len(interfaceCases)-1].typ.Type().IsEmptyInterface() {
   522  			anyGoto = interfaceCases[len(interfaceCases)-1].jmp
   523  			interfaceCases = interfaceCases[:len(interfaceCases)-1]
   524  		}
   525  
   526  		// Next, process all the interface types with a single call to the runtime.
   527  		if len(interfaceCases) > 0 {
   528  
   529  			// Build an internal/abi.InterfaceSwitch descriptor to pass to the runtime.
   530  			lsym := types.LocalPkg.Lookup(fmt.Sprintf(".interfaceSwitch.%d", interfaceSwitchGen)).LinksymABI(obj.ABI0)
   531  			interfaceSwitchGen++
   532  			c := rttype.NewCursor(lsym, 0, rttype.InterfaceSwitch)
   533  			c.Field("Cache").WritePtr(typecheck.LookupRuntimeVar("emptyInterfaceSwitchCache"))
   534  			c.Field("NCases").WriteInt(int64(len(interfaceCases)))
   535  			array, sizeDelta := c.Field("Cases").ModifyArray(len(interfaceCases))
   536  			for i, c := range interfaceCases {
   537  				array.Elem(i).WritePtr(reflectdata.TypeLinksym(c.typ.Type()))
   538  			}
   539  			objw.Global(lsym, int32(rttype.InterfaceSwitch.Size()+sizeDelta), obj.LOCAL)
   540  			// The GC only needs to see the first pointer in the structure (all the others
   541  			// are to static locations). So the InterfaceSwitch type itself is fine, even
   542  			// though it might not cover the whole array we wrote above.
   543  			lsym.Gotype = reflectdata.TypeLinksym(rttype.InterfaceSwitch)
   544  
   545  			// Call runtime to do switch
   546  			// case, itab = runtime.interfaceSwitch(&descriptor, typeof(arg))
   547  			var typeArg ir.Node
   548  			if s.srcName.Type().IsEmptyInterface() {
   549  				typeArg = ir.NewConvExpr(base.Pos, ir.OCONVNOP, types.Types[types.TUINT8].PtrTo(), srcItab)
   550  			} else {
   551  				typeArg = itabType(srcItab)
   552  			}
   553  			caseVar := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TINT])
   554  			isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, dotHash, lsym)
   555  			sw.Compiled.Append(isw)
   556  
   557  			// Switch on the result of the call (or cache lookup).
   558  			var newCases []*ir.CaseClause
   559  			for i, c := range interfaceCases {
   560  				newCases = append(newCases, &ir.CaseClause{
   561  					List: []ir.Node{ir.NewInt(base.Pos, int64(i))},
   562  					Body: []ir.Node{c.jmp},
   563  				})
   564  			}
   565  			// TODO: add len(newCases) case, mark switch as bounded
   566  			sw2 := ir.NewSwitchStmt(base.Pos, caseVar, newCases)
   567  			sw.Compiled.Append(typecheck.Stmt(sw2))
   568  			interfaceCases = interfaceCases[:0]
   569  		}
   570  
   571  		if anyGoto != nil {
   572  			// We've already handled the nil case, so everything
   573  			// that reaches here matches the "any" case.
   574  			sw.Compiled.Append(anyGoto)
   575  		}
   576  	}
   577  caseLoop:
   578  	for _, c := range cases {
   579  		if c.typ.Op() == ir.ODYNAMICTYPE {
   580  			flush() // process all previous cases
   581  			dt := c.typ.(*ir.DynamicType)
   582  			dot := ir.NewDynamicTypeAssertExpr(c.pos, ir.ODYNAMICDOTTYPE, s.srcName, dt.RType)
   583  			dot.ITab = dt.ITab
   584  			dot.SetType(c.typ.Type())
   585  			dot.SetTypecheck(1)
   586  
   587  			as := ir.NewAssignListStmt(c.pos, ir.OAS2, nil, nil)
   588  			as.Lhs = []ir.Node{ir.BlankNode, s.okName} // _, ok =
   589  			if c.val != nil {
   590  				as.Lhs[0] = c.val // tmpVar, ok =
   591  			}
   592  			as.Rhs = []ir.Node{dot}
   593  			typecheck.Stmt(as)
   594  
   595  			nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
   596  			sw.Compiled.Append(as, nif)
   597  			continue
   598  		}
   599  
   600  		// Check for shadowing (a case that will never fire because
   601  		// a previous case would have always fired first). This check
   602  		// allows us to reorder concrete and interface cases.
   603  		// (TODO: these should be vet failures, maybe?)
   604  		for _, ic := range interfaceCases {
   605  			// An interface type case will shadow all
   606  			// subsequent types that implement that interface.
   607  			if typecheck.Implements(c.typ.Type(), ic.typ.Type()) {
   608  				continue caseLoop
   609  			}
   610  			// Note that we don't need to worry about:
   611  			// 1. Two concrete types shadowing each other. That's
   612  			//    disallowed by the spec.
   613  			// 2. A concrete type shadowing an interface type.
   614  			//    That can never happen, as interface types can
   615  			//    be satisfied by an infinite set of concrete types.
   616  			// The correctness of this step also depends on handling
   617  			// the dynamic type cases separately, as we do above.
   618  		}
   619  
   620  		if c.typ.Type().IsInterface() {
   621  			interfaceCases = append(interfaceCases, c)
   622  		} else {
   623  			concreteCases = append(concreteCases, c)
   624  		}
   625  	}
   626  	flush()
   627  
   628  	sw.Compiled.Append(defaultGoto) // if none of the cases matched
   629  
   630  	// Now generate all the case bodies
   631  	for i, ncase := range sw.Cases {
   632  		sw.Compiled.Append(ir.NewLabelStmt(ncase.Pos(), labels[i]))
   633  		if caseVar := ncase.Var; caseVar != nil {
   634  			val := s.srcName
   635  			if len(ncase.List) == 1 {
   636  				// single type. We have to downcast the input value to the target type.
   637  				if ncase.List[0].Op() == ir.OTYPE { // single compile-time known type
   638  					t := ncase.List[0].Type()
   639  					if t.IsInterface() {
   640  						// This case is an interface. Build case value from input interface.
   641  						// The data word will always be the same, but the itab/type changes.
   642  						if t.IsEmptyInterface() {
   643  							var typ ir.Node
   644  							if s.srcName.Type().IsEmptyInterface() {
   645  								// E->E, nothing to do, type is already correct.
   646  								typ = srcItab
   647  							} else {
   648  								// I->E, load type out of itab
   649  								typ = itabType(srcItab)
   650  								typ.SetPos(ncase.Pos())
   651  							}
   652  							val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, typ, srcData)
   653  						} else {
   654  							// The itab we need was returned by a runtime.interfaceSwitch call.
   655  							val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, s.itabName, srcData)
   656  						}
   657  					} else {
   658  						// This case is a concrete type, just read its value out of the interface.
   659  						val = ifaceData(ncase.Pos(), s.srcName, t)
   660  					}
   661  				} else if ncase.List[0].Op() == ir.ODYNAMICTYPE { // single runtime known type
   662  					var found bool
   663  					for _, c := range cases {
   664  						if c.idx == i {
   665  							val = c.val
   666  							found = val != nil
   667  							break
   668  						}
   669  					}
   670  					// the tmpVar must always be found
   671  					if !found {
   672  						base.Fatalf("an error occurred when processing type switch case %v", ncase.List[0])
   673  					}
   674  				} else if ir.IsNil(ncase.List[0]) {
   675  				} else {
   676  					base.Fatalf("unhandled type switch case %v", ncase.List[0])
   677  				}
   678  				val.SetType(caseVar.Type())
   679  				val.SetTypecheck(1)
   680  			}
   681  			l := []ir.Node{
   682  				ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
   683  				ir.NewAssignStmt(ncase.Pos(), caseVar, val),
   684  			}
   685  			typecheck.Stmts(l)
   686  			sw.Compiled.Append(l...)
   687  		}
   688  		sw.Compiled.Append(ncase.Body...)
   689  		sw.Compiled.Append(br)
   690  	}
   691  
   692  	walkStmtList(sw.Compiled)
   693  	sw.Tag = nil
   694  	sw.Cases = nil
   695  }
   696  
   697  var interfaceSwitchGen int
   698  
   699  // typeHashFieldOf returns an expression to select the type hash field
   700  // from an interface's descriptor word (whether a *runtime._type or
   701  // *runtime.itab pointer).
   702  func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
   703  	if itab.Op() != ir.OITAB {
   704  		base.Fatalf("expected OITAB, got %v", itab.Op())
   705  	}
   706  	var hashField *types.Field
   707  	if itab.X.Type().IsEmptyInterface() {
   708  		// runtime._type's hash field
   709  		if rtypeHashField == nil {
   710  			rtypeHashField = runtimeField("hash", rttype.Type.OffsetOf("Hash"), types.Types[types.TUINT32])
   711  		}
   712  		hashField = rtypeHashField
   713  	} else {
   714  		// runtime.itab's hash field
   715  		if itabHashField == nil {
   716  			itabHashField = runtimeField("hash", rttype.ITab.OffsetOf("Hash"), types.Types[types.TUINT32])
   717  		}
   718  		hashField = itabHashField
   719  	}
   720  	return boundedDotPtr(pos, itab, hashField)
   721  }
   722  
   723  var rtypeHashField, itabHashField *types.Field
   724  
   725  // A typeSwitch walks a type switch.
   726  type typeSwitch struct {
   727  	// Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
   728  	srcName  ir.Node // value being type-switched on
   729  	hashName ir.Node // type hash of the value being type-switched on
   730  	okName   ir.Node // boolean used for comma-ok type assertions
   731  	itabName ir.Node // itab value to use for first word of non-empty interface
   732  }
   733  
   734  type typeClause struct {
   735  	hash uint32
   736  	body ir.Nodes
   737  }
   738  
   739  func (s *typeSwitch) flush(cc []typeClause, compiled *ir.Nodes) {
   740  	if len(cc) == 0 {
   741  		return
   742  	}
   743  
   744  	slices.SortFunc(cc, func(a, b typeClause) int { return cmp.Compare(a.hash, b.hash) })
   745  
   746  	// Combine adjacent cases with the same hash.
   747  	merged := cc[:1]
   748  	for _, c := range cc[1:] {
   749  		last := &merged[len(merged)-1]
   750  		if last.hash == c.hash {
   751  			last.body.Append(c.body.Take()...)
   752  		} else {
   753  			merged = append(merged, c)
   754  		}
   755  	}
   756  	cc = merged
   757  
   758  	if s.tryJumpTable(cc, compiled) {
   759  		return
   760  	}
   761  	binarySearch(len(cc), compiled,
   762  		func(i int) ir.Node {
   763  			return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashName, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
   764  		},
   765  		func(i int, nif *ir.IfStmt) {
   766  			// TODO(mdempsky): Omit hash equality check if
   767  			// there's only one type.
   768  			c := cc[i]
   769  			nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashName, ir.NewInt(base.Pos, int64(c.hash)))
   770  			nif.Body.Append(c.body.Take()...)
   771  		},
   772  	)
   773  }
   774  
   775  // Try to implement the clauses with a jump table. Returns true if successful.
   776  func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
   777  	const minCases = 5 // have at least minCases cases in the switch
   778  	if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
   779  		return false
   780  	}
   781  	if len(cc) < minCases {
   782  		return false // not enough cases for it to be worth it
   783  	}
   784  	hashes := make([]uint32, len(cc))
   785  	// b = # of bits to use. Start with the minimum number of
   786  	// bits possible, but try a few larger sizes if needed.
   787  	b0 := bits.Len(uint(len(cc) - 1))
   788  	for b := b0; b < b0+3; b++ {
   789  	pickI:
   790  		for i := 0; i <= 32-b; i++ { // starting bit position
   791  			// Compute the hash we'd get from all the cases,
   792  			// selecting b bits starting at bit i.
   793  			hashes = hashes[:0]
   794  			for _, c := range cc {
   795  				h := c.hash >> i & (1<<b - 1)
   796  				hashes = append(hashes, h)
   797  			}
   798  			// Order by increasing hash.
   799  			slices.Sort(hashes)
   800  			for j := 1; j < len(hashes); j++ {
   801  				if hashes[j] == hashes[j-1] {
   802  					// There is a duplicate hash; try a different b/i pair.
   803  					continue pickI
   804  				}
   805  			}
   806  
   807  			// All hashes are distinct. Use these values of b and i.
   808  			h := s.hashName
   809  			if i != 0 {
   810  				h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
   811  			}
   812  			h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
   813  			h = typecheck.Expr(h)
   814  
   815  			// Build jump table.
   816  			jt := ir.NewJumpTableStmt(base.Pos, h)
   817  			jt.Cases = make([]constant.Value, 1<<b)
   818  			jt.Targets = make([]*types.Sym, 1<<b)
   819  			out.Append(jt)
   820  
   821  			// Start with all hashes going to the didn't-match target.
   822  			noMatch := typecheck.AutoLabel(".s")
   823  			for j := 0; j < 1<<b; j++ {
   824  				jt.Cases[j] = constant.MakeInt64(int64(j))
   825  				jt.Targets[j] = noMatch
   826  			}
   827  			// This statement is not reachable, but it will make it obvious that we don't
   828  			// fall through to the first case.
   829  			out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
   830  
   831  			// Emit each of the actual cases.
   832  			for _, c := range cc {
   833  				h := c.hash >> i & (1<<b - 1)
   834  				label := typecheck.AutoLabel(".s")
   835  				jt.Targets[h] = label
   836  				out.Append(ir.NewLabelStmt(base.Pos, label))
   837  				out.Append(c.body...)
   838  				// We reach here if the hash matches but the type equality test fails.
   839  				out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
   840  			}
   841  			// Emit point to go to if type doesn't match any case.
   842  			out.Append(ir.NewLabelStmt(base.Pos, noMatch))
   843  			return true
   844  		}
   845  	}
   846  	// Couldn't find a perfect hash. Fall back to binary search.
   847  	return false
   848  }
   849  
   850  // binarySearch constructs a binary search tree for handling n cases,
   851  // and appends it to out. It's used for efficiently implementing
   852  // switch statements.
   853  //
   854  // less(i) should return a boolean expression. If it evaluates true,
   855  // then cases before i will be tested; otherwise, cases i and later.
   856  //
   857  // leaf(i, nif) should setup nif (an OIF node) to test case i. In
   858  // particular, it should set nif.Cond and nif.Body.
   859  func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
   860  	const binarySearchMin = 4 // minimum number of cases for binary search
   861  
   862  	var do func(lo, hi int, out *ir.Nodes)
   863  	do = func(lo, hi int, out *ir.Nodes) {
   864  		n := hi - lo
   865  		if n < binarySearchMin {
   866  			for i := lo; i < hi; i++ {
   867  				nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
   868  				leaf(i, nif)
   869  				base.Pos = base.Pos.WithNotStmt()
   870  				nif.Cond = typecheck.Expr(nif.Cond)
   871  				nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
   872  				out.Append(nif)
   873  				out = &nif.Else
   874  			}
   875  			return
   876  		}
   877  
   878  		half := lo + n/2
   879  		nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
   880  		nif.Cond = less(half)
   881  		base.Pos = base.Pos.WithNotStmt()
   882  		nif.Cond = typecheck.Expr(nif.Cond)
   883  		nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
   884  		do(lo, half, &nif.Body)
   885  		do(half, hi, &nif.Else)
   886  		out.Append(nif)
   887  	}
   888  
   889  	do(0, n, out)
   890  }
   891  
   892  func stringSearch(expr ir.Node, cc []exprClause, out *ir.Nodes) {
   893  	if len(cc) < 4 {
   894  		// Short list, just do brute force equality checks.
   895  		for _, c := range cc {
   896  			nif := ir.NewIfStmt(base.Pos.WithNotStmt(), typecheck.DefaultLit(typecheck.Expr(c.test(expr)), nil), []ir.Node{c.jmp}, nil)
   897  			out.Append(nif)
   898  			out = &nif.Else
   899  		}
   900  		return
   901  	}
   902  
   903  	// The strategy here is to find a simple test to divide the set of possible strings
   904  	// that might match expr approximately in half.
   905  	// The test we're going to use is to do an ordered comparison of a single byte
   906  	// of expr to a constant. We will pick the index of that byte and the value we're
   907  	// comparing against to make the split as even as possible.
   908  	//   if expr[3] <= 'd' { ... search strings with expr[3] at 'd' or lower  ... }
   909  	//   else              { ... search strings with expr[3] at 'e' or higher ... }
   910  	//
   911  	// To add complication, we will do the ordered comparison in the signed domain.
   912  	// The reason for this is to prevent CSE from merging the load used for the
   913  	// ordered comparison with the load used for the later equality check.
   914  	//   if expr[3] <= 'd' { ... if expr[0] == 'f' && expr[1] == 'o' && expr[2] == 'o' && expr[3] == 'd' { ... } }
   915  	// If we did both expr[3] loads in the unsigned domain, they would be CSEd, and that
   916  	// would in turn defeat the combining of expr[0]...expr[3] into a single 4-byte load.
   917  	// See issue 48222.
   918  	// By using signed loads for the ordered comparison and unsigned loads for the
   919  	// equality comparison, they don't get CSEd and the equality comparisons will be
   920  	// done using wider loads.
   921  
   922  	n := len(ir.StringVal(cc[0].lo)) // Length of the constant strings.
   923  	bestScore := int64(0)            // measure of how good the split is.
   924  	bestIdx := 0                     // split using expr[bestIdx]
   925  	bestByte := int8(0)              // compare expr[bestIdx] against bestByte
   926  	for idx := 0; idx < n; idx++ {
   927  		for b := int8(-128); b < 127; b++ {
   928  			le := 0
   929  			for _, c := range cc {
   930  				s := ir.StringVal(c.lo)
   931  				if int8(s[idx]) <= b {
   932  					le++
   933  				}
   934  			}
   935  			score := int64(le) * int64(len(cc)-le)
   936  			if score > bestScore {
   937  				bestScore = score
   938  				bestIdx = idx
   939  				bestByte = b
   940  			}
   941  		}
   942  	}
   943  
   944  	// The split must be at least 1:n-1 because we have at least 2 distinct strings; they
   945  	// have to be different somewhere.
   946  	// TODO: what if the best split is still pretty bad?
   947  	if bestScore == 0 {
   948  		base.Fatalf("unable to split string set")
   949  	}
   950  
   951  	// Convert expr to a []int8
   952  	slice := ir.NewConvExpr(base.Pos, ir.OSTR2BYTESTMP, types.NewSlice(types.Types[types.TINT8]), expr)
   953  	slice.SetTypecheck(1) // legacy typechecker doesn't handle this op
   954  	slice.MarkNonNil()
   955  	// Load the byte we're splitting on.
   956  	load := ir.NewIndexExpr(base.Pos, slice, ir.NewInt(base.Pos, int64(bestIdx)))
   957  	// Compare with the value we're splitting on.
   958  	cmp := ir.Node(ir.NewBinaryExpr(base.Pos, ir.OLE, load, ir.NewInt(base.Pos, int64(bestByte))))
   959  	cmp = typecheck.DefaultLit(typecheck.Expr(cmp), nil)
   960  	nif := ir.NewIfStmt(base.Pos, cmp, nil, nil)
   961  
   962  	var le []exprClause
   963  	var gt []exprClause
   964  	for _, c := range cc {
   965  		s := ir.StringVal(c.lo)
   966  		if int8(s[bestIdx]) <= bestByte {
   967  			le = append(le, c)
   968  		} else {
   969  			gt = append(gt, c)
   970  		}
   971  	}
   972  	stringSearch(expr, le, &nif.Body)
   973  	stringSearch(expr, gt, &nif.Else)
   974  	out.Append(nif)
   975  
   976  	// TODO: if expr[bestIdx] has enough different possible values, use a jump table.
   977  }
   978  

View as plain text