Source file src/cmd/compile/internal/inline/inlheur/texpr_classify_test.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inlheur
     6  
     7  import (
     8  	"cmd/compile/internal/base"
     9  	"cmd/compile/internal/ir"
    10  	"cmd/compile/internal/typecheck"
    11  	"cmd/compile/internal/types"
    12  	"cmd/internal/obj"
    13  	"cmd/internal/src"
    14  	"cmd/internal/sys"
    15  	"go/constant"
    16  	"testing"
    17  )
    18  
    19  var pos src.XPos
    20  var local *types.Pkg
    21  var f *ir.Func
    22  
    23  func init() {
    24  	types.PtrSize = 8
    25  	types.RegSize = 8
    26  	types.MaxWidth = 1 << 50
    27  	base.Ctxt = &obj.Link{Arch: &obj.LinkArch{Arch: &sys.Arch{Alignment: 1, CanMergeLoads: true}}}
    28  
    29  	typecheck.InitUniverse()
    30  	local = types.NewPkg("", "")
    31  	fsym := &types.Sym{
    32  		Pkg:  types.NewPkg("my/import/path", "path"),
    33  		Name: "function",
    34  	}
    35  	f = ir.NewFunc(src.NoXPos, src.NoXPos, fsym, nil)
    36  }
    37  
    38  type state struct {
    39  	ntab map[string]*ir.Name
    40  }
    41  
    42  func mkstate() *state {
    43  	return &state{
    44  		ntab: make(map[string]*ir.Name),
    45  	}
    46  }
    47  
    48  func bin(x ir.Node, op ir.Op, y ir.Node) ir.Node {
    49  	return ir.NewBinaryExpr(pos, op, x, y)
    50  }
    51  
    52  func conv(x ir.Node, t *types.Type) ir.Node {
    53  	return ir.NewConvExpr(pos, ir.OCONV, t, x)
    54  }
    55  
    56  func logical(x ir.Node, op ir.Op, y ir.Node) ir.Node {
    57  	return ir.NewLogicalExpr(pos, op, x, y)
    58  }
    59  
    60  func un(op ir.Op, x ir.Node) ir.Node {
    61  	return ir.NewUnaryExpr(pos, op, x)
    62  }
    63  
    64  func liti(i int64) ir.Node {
    65  	return ir.NewBasicLit(pos, types.Types[types.TINT64], constant.MakeInt64(i))
    66  }
    67  
    68  func lits(s string) ir.Node {
    69  	return ir.NewBasicLit(pos, types.Types[types.TSTRING], constant.MakeString(s))
    70  }
    71  
    72  func (s *state) nm(name string, t *types.Type) *ir.Name {
    73  	if n, ok := s.ntab[name]; ok {
    74  		if n.Type() != t {
    75  			panic("bad")
    76  		}
    77  		return n
    78  	}
    79  	sym := local.Lookup(name)
    80  	nn := ir.NewNameAt(pos, sym, t)
    81  	s.ntab[name] = nn
    82  	return nn
    83  }
    84  
    85  func (s *state) nmi64(name string) *ir.Name {
    86  	return s.nm(name, types.Types[types.TINT64])
    87  }
    88  
    89  func (s *state) nms(name string) *ir.Name {
    90  	return s.nm(name, types.Types[types.TSTRING])
    91  }
    92  
    93  func TestClassifyIntegerCompare(t *testing.T) {
    94  
    95  	// (n < 10 || n > 100) && (n >= 12 || n <= 99 || n != 101)
    96  	s := mkstate()
    97  	nn := s.nmi64("n")
    98  	nlt10 := bin(nn, ir.OLT, liti(10))         // n < 10
    99  	ngt100 := bin(nn, ir.OGT, liti(100))       // n > 100
   100  	nge12 := bin(nn, ir.OGE, liti(12))         // n >= 12
   101  	nle99 := bin(nn, ir.OLE, liti(99))         // n < 10
   102  	nne101 := bin(nn, ir.ONE, liti(101))       // n != 101
   103  	noror1 := logical(nlt10, ir.OOROR, ngt100) // n < 10 || n > 100
   104  	noror2 := logical(nge12, ir.OOROR, nle99)  // n >= 12 || n <= 99
   105  	noror3 := logical(noror2, ir.OOROR, nne101)
   106  	nandand := typecheck.Expr(logical(noror1, ir.OANDAND, noror3))
   107  
   108  	wantv := true
   109  	v := ShouldFoldIfNameConstant(nandand, []*ir.Name{nn})
   110  	if v != wantv {
   111  		t.Errorf("wanted shouldfold(%v) %v, got %v", nandand, wantv, v)
   112  	}
   113  }
   114  
   115  func TestClassifyStringCompare(t *testing.T) {
   116  
   117  	// s != "foo" && s < "ooblek" && s > "plarkish"
   118  	s := mkstate()
   119  	nn := s.nms("s")
   120  	snefoo := bin(nn, ir.ONE, lits("foo"))     // s != "foo"
   121  	sltoob := bin(nn, ir.OLT, lits("ooblek"))  // s < "ooblek"
   122  	sgtpk := bin(nn, ir.OGT, lits("plarkish")) // s > "plarkish"
   123  	nandand := logical(snefoo, ir.OANDAND, sltoob)
   124  	top := typecheck.Expr(logical(nandand, ir.OANDAND, sgtpk))
   125  
   126  	wantv := true
   127  	v := ShouldFoldIfNameConstant(top, []*ir.Name{nn})
   128  	if v != wantv {
   129  		t.Errorf("wanted shouldfold(%v) %v, got %v", top, wantv, v)
   130  	}
   131  }
   132  
   133  func TestClassifyIntegerArith(t *testing.T) {
   134  	// n+1 ^ n-3 * n/2 + n<<9 + n>>2 - n&^7
   135  
   136  	s := mkstate()
   137  	nn := s.nmi64("n")
   138  	np1 := bin(nn, ir.OADD, liti(1))     // n+1
   139  	nm3 := bin(nn, ir.OSUB, liti(3))     // n-3
   140  	nd2 := bin(nn, ir.ODIV, liti(2))     // n/2
   141  	nls9 := bin(nn, ir.OLSH, liti(9))    // n<<9
   142  	nrs2 := bin(nn, ir.ORSH, liti(2))    // n>>2
   143  	nan7 := bin(nn, ir.OANDNOT, liti(7)) // n&^7
   144  	c1xor := bin(np1, ir.OXOR, nm3)
   145  	c2mul := bin(c1xor, ir.OMUL, nd2)
   146  	c3add := bin(c2mul, ir.OADD, nls9)
   147  	c4add := bin(c3add, ir.OADD, nrs2)
   148  	c5sub := bin(c4add, ir.OSUB, nan7)
   149  	top := typecheck.Expr(c5sub)
   150  
   151  	wantv := true
   152  	v := ShouldFoldIfNameConstant(top, []*ir.Name{nn})
   153  	if v != wantv {
   154  		t.Errorf("wanted shouldfold(%v) %v, got %v", top, wantv, v)
   155  	}
   156  }
   157  
   158  func TestClassifyAssortedShifts(t *testing.T) {
   159  
   160  	s := mkstate()
   161  	nn := s.nmi64("n")
   162  	badcases := []ir.Node{
   163  		bin(liti(3), ir.OLSH, nn), // 3<<n
   164  		bin(liti(7), ir.ORSH, nn), // 7>>n
   165  	}
   166  	for _, bc := range badcases {
   167  		wantv := false
   168  		v := ShouldFoldIfNameConstant(typecheck.Expr(bc), []*ir.Name{nn})
   169  		if v != wantv {
   170  			t.Errorf("wanted shouldfold(%v) %v, got %v", bc, wantv, v)
   171  		}
   172  	}
   173  }
   174  
   175  func TestClassifyFloat(t *testing.T) {
   176  	// float32(n) + float32(10)
   177  	s := mkstate()
   178  	nn := s.nm("n", types.Types[types.TUINT32])
   179  	f1 := conv(nn, types.Types[types.TFLOAT32])
   180  	f2 := conv(liti(10), types.Types[types.TFLOAT32])
   181  	add := bin(f1, ir.OADD, f2)
   182  
   183  	wantv := false
   184  	v := ShouldFoldIfNameConstant(typecheck.Expr(add), []*ir.Name{nn})
   185  	if v != wantv {
   186  		t.Errorf("wanted shouldfold(%v) %v, got %v", add, wantv, v)
   187  	}
   188  }
   189  
   190  func TestMultipleNamesAllUsed(t *testing.T) {
   191  	// n != 101 && m < 2
   192  	s := mkstate()
   193  	nn := s.nmi64("n")
   194  	nm := s.nmi64("m")
   195  	nne101 := bin(nn, ir.ONE, liti(101)) // n != 101
   196  	mlt2 := bin(nm, ir.OLT, liti(2))     // m < 2
   197  	nandand := typecheck.Expr(logical(nne101, ir.OANDAND, mlt2))
   198  
   199  	// all names used
   200  	wantv := true
   201  	v := ShouldFoldIfNameConstant(nandand, []*ir.Name{nn, nm})
   202  	if v != wantv {
   203  		t.Errorf("wanted shouldfold(%v) %v, got %v", nandand, wantv, v)
   204  	}
   205  
   206  	// not all names used
   207  	wantv = false
   208  	v = ShouldFoldIfNameConstant(nne101, []*ir.Name{nn, nm})
   209  	if v != wantv {
   210  		t.Errorf("wanted shouldfold(%v) %v, got %v", nne101, wantv, v)
   211  	}
   212  
   213  	// other names used.
   214  	np := s.nmi64("p")
   215  	pne0 := bin(np, ir.ONE, liti(101)) // p != 0
   216  	noror := logical(nandand, ir.OOROR, pne0)
   217  	wantv = false
   218  	v = ShouldFoldIfNameConstant(noror, []*ir.Name{nn, nm})
   219  	if v != wantv {
   220  		t.Errorf("wanted shouldfold(%v) %v, got %v", noror, wantv, v)
   221  	}
   222  }
   223  

View as plain text