Source file src/cmd/compile/internal/test/testdata/gen/conditionalCmpConstGen.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This program generates tests to verify that conditional comparisons
     6  // with constants are properly optimized by the compiler through constant folding.
     7  // The generated test should be compiled with a known working version of Go.
     8  // Run with `go run conditionalCmpConstGen.go` to generate a file called
     9  // conditionalCmpConst_test.go in the grandparent directory.
    10  
    11  package main
    12  
    13  import (
    14  	"bytes"
    15  	"fmt"
    16  	"go/format"
    17  	"log"
    18  	"os"
    19  	"strings"
    20  )
    21  
    22  // IntegerConstraint defines a type constraint for all integer types
    23  func writeIntegerConstraint(w *bytes.Buffer) {
    24  	fmt.Fprintf(w, "type IntegerConstraint interface {\n")
    25  	fmt.Fprintf(w, "\tint | uint | int8 | uint8 | int16 | ")
    26  	fmt.Fprintf(w, "uint16 | int32 | uint32 | int64 | uint64\n")
    27  	fmt.Fprintf(w, "}\n\n")
    28  }
    29  
    30  // TestCase describes a parameterized test case with comparison and logical operations
    31  func writeTestCaseStruct(w *bytes.Buffer) {
    32  	fmt.Fprintf(w, "type TestCase[T IntegerConstraint] struct {\n")
    33  	fmt.Fprintf(w, "\tcmp1, cmp2 func(a, b T) bool\n")
    34  	fmt.Fprintf(w, "\tcombine func(x, y bool) bool\n")
    35  	fmt.Fprintf(w, "\ttargetFunc func(a, b, c, d T) bool\n")
    36  	fmt.Fprintf(w, "\tcmp1Expr, cmp2Expr, logicalExpr string // String representations for debugging\n")
    37  	fmt.Fprintf(w, "}\n\n")
    38  }
    39  
    40  // BoundaryValues contains base value and its variations for edge case testing
    41  func writeBoundaryValuesStruct(w *bytes.Buffer) {
    42  	fmt.Fprintf(w, "type BoundaryValues[T IntegerConstraint] struct {\n")
    43  	fmt.Fprintf(w, "\tbase T\n")
    44  	fmt.Fprintf(w, "\tvariants [3]T\n")
    45  	fmt.Fprintf(w, "}\n\n")
    46  }
    47  
    48  // writeTypeDefinitions generates all necessary type declarations
    49  func writeTypeDefinitions(w *bytes.Buffer) {
    50  	writeIntegerConstraint(w)
    51  	writeTestCaseStruct(w)
    52  	writeBoundaryValuesStruct(w)
    53  }
    54  
    55  // comparisonOperators contains format strings for comparison operators
    56  var comparisonOperators = []string{
    57  	"%s == %s", "%s <= %s", "%s < %s",
    58  	"%s != %s", "%s >= %s", "%s > %s",
    59  }
    60  
    61  // logicalOperators contains format strings for logical combination of boolean expressions
    62  var logicalOperators = []string{
    63  	"(%s) && (%s)", "(%s) && !(%s)", "!(%s) && (%s)", "!(%s) && !(%s)",
    64  	"(%s) || (%s)", "(%s) || !(%s)", "!(%s) || (%s)", "!(%s) || !(%s)",
    65  }
    66  
    67  // writeComparator generates a comparator function based on the comparison operator
    68  func writeComparator(w *bytes.Buffer, fieldName, operatorFormat string) {
    69  	expression := fmt.Sprintf(operatorFormat, "a", "b")
    70  	fmt.Fprintf(w, "\t\t\t%s: func(a, b T) bool { return %s },\n", fieldName, expression)
    71  }
    72  
    73  // writeLogicalCombiner generates a function to combine two boolean values
    74  func writeLogicalCombiner(w *bytes.Buffer, logicalOperator string) {
    75  	expression := fmt.Sprintf(logicalOperator, "x", "y")
    76  	fmt.Fprintf(w, "\t\t\tcombine: func(x, y bool) bool { return %s },\n", expression)
    77  }
    78  
    79  // writeTargetFunction generates the target function with conditional expression
    80  func writeTargetFunction(w *bytes.Buffer, cmp1, cmp2, logicalOp string) {
    81  	leftExpr := fmt.Sprintf(cmp1, "a", "b")
    82  	rightExpr := fmt.Sprintf(cmp2, "c", "d")
    83  	condition := fmt.Sprintf(logicalOp, leftExpr, rightExpr)
    84  
    85  	fmt.Fprintf(w, "\t\t\ttargetFunc: func(a, b, c, d T) bool {\n")
    86  	fmt.Fprintf(w, "\t\t\t\tif %s {\n", condition)
    87  	fmt.Fprintf(w, "\t\t\t\t\treturn true\n")
    88  	fmt.Fprintf(w, "\t\t\t\t}\n")
    89  	fmt.Fprintf(w, "\t\t\t\treturn false\n")
    90  	fmt.Fprintf(w, "\t\t\t},\n")
    91  }
    92  
    93  // writeTestCase creates a single test case with given comparison and logical operators
    94  func writeTestCase(w *bytes.Buffer, cmp1, cmp2, logicalOp string) {
    95  	fmt.Fprintf(w, "\t\t{\n")
    96  	writeComparator(w, "cmp1", cmp1)
    97  	writeComparator(w, "cmp2", cmp2)
    98  	writeLogicalCombiner(w, logicalOp)
    99  	writeTargetFunction(w, cmp1, cmp2, logicalOp)
   100  
   101  	// Store string representations for debugging
   102  	cmp1Expr := fmt.Sprintf(cmp1, "a", "b")
   103  	cmp2Expr := fmt.Sprintf(cmp2, "c", "d")
   104  	logicalExpr := fmt.Sprintf(logicalOp, cmp1Expr, cmp2Expr)
   105  
   106  	fmt.Fprintf(w, "\t\t\tcmp1Expr: %q,\n", cmp1Expr)
   107  	fmt.Fprintf(w, "\t\t\tcmp2Expr: %q,\n", cmp2Expr)
   108  	fmt.Fprintf(w, "\t\t\tlogicalExpr: %q,\n", logicalExpr)
   109  
   110  	fmt.Fprintf(w, "\t\t},\n")
   111  }
   112  
   113  // generateTestCases creates a slice of all possible test cases
   114  func generateTestCases(w *bytes.Buffer) {
   115  	fmt.Fprintf(w, "func generateTestCases[T IntegerConstraint]() []TestCase[T] {\n")
   116  	fmt.Fprintf(w, "\treturn []TestCase[T]{\n")
   117  
   118  	for _, cmp1 := range comparisonOperators {
   119  		for _, cmp2 := range comparisonOperators {
   120  			for _, logicalOp := range logicalOperators {
   121  				writeTestCase(w, cmp1, cmp2, logicalOp)
   122  			}
   123  		}
   124  	}
   125  
   126  	fmt.Fprintf(w, "\t}\n")
   127  	fmt.Fprintf(w, "}\n\n")
   128  }
   129  
   130  // TypeConfig defines a type and its test base value
   131  type TypeConfig struct {
   132  	typeName, baseValue string
   133  }
   134  
   135  // typeConfigs contains all integer types to test with their base values
   136  var typeConfigs = []TypeConfig{
   137  	{typeName: "int8", baseValue: "1 << 6"},
   138  	{typeName: "uint8", baseValue: "1 << 6"},
   139  	{typeName: "int16", baseValue: "1 << 14"},
   140  	{typeName: "uint16", baseValue: "1 << 14"},
   141  	{typeName: "int32", baseValue: "1 << 30"},
   142  	{typeName: "uint32", baseValue: "1 << 30"},
   143  	{typeName: "int", baseValue: "1 << 30"},
   144  	{typeName: "uint", baseValue: "1 << 30"},
   145  	{typeName: "int64", baseValue: "1 << 62"},
   146  	{typeName: "uint64", baseValue: "1 << 62"},
   147  }
   148  
   149  // writeTypeSpecificTest generates test for a specific integer type
   150  func writeTypeSpecificTest(w *bytes.Buffer, typeName, baseValue string) {
   151  	typeTitle := strings.Title(typeName)
   152  
   153  	fmt.Fprintf(w, "func Test%sConditionalCmpConst(t *testing.T) {\n", typeTitle)
   154  
   155  	fmt.Fprintf(w, "\ttestCases := generateTestCases[%s]()\n", typeName)
   156  	fmt.Fprintf(w, "\tbase := %s(%s)\n", typeName, baseValue)
   157  	fmt.Fprintf(w, "\tvalues := [3]%s{base - 1, base, base + 1}\n\n", typeName)
   158  
   159  	fmt.Fprintf(w, "\tfor _, tc := range testCases {\n")
   160  	fmt.Fprintf(w, "\t\ta, c := base, base\n")
   161  	fmt.Fprintf(w, "\t\tfor _, b := range values {\n")
   162  	fmt.Fprintf(w, "\t\t\tfor _, d := range values {\n")
   163  	fmt.Fprintf(w, "\t\t\t\texpected := tc.combine(tc.cmp1(a, b), tc.cmp2(c, d))\n")
   164  	fmt.Fprintf(w, "\t\t\t\tactual := tc.targetFunc(a, b, c, d)\n")
   165  	fmt.Fprintf(w, "\t\t\t\tif actual != expected {\n")
   166  	fmt.Fprintf(w, "\t\t\t\t\tt.Errorf(\"conditional comparison failed:\\n\"+\n")
   167  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  type: %%T\\n\"+\n")
   168  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  condition: %%s\\n\"+\n")
   169  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  values: a=%%v, b=%%v, c=%%v, d=%%v\\n\"+\n")
   170  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  cmp1(a,b)=%%v (%%s)\\n\"+\n")
   171  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  cmp2(c,d)=%%v (%%s)\\n\"+\n")
   172  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  expected: combine(%%v, %%v)=%%v\\n\"+\n")
   173  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  actual: %%v\\n\"+\n")
   174  	fmt.Fprintf(w, "\t\t\t\t\t\t\"  logical expression: %%s\",\n")
   175  	fmt.Fprintf(w, "\t\t\t\t\t\ta,\n")
   176  	fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr,\n")
   177  	fmt.Fprintf(w, "\t\t\t\t\t\ta, b, c, d,\n")
   178  	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp1Expr,\n")
   179  	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp2(c, d), tc.cmp2Expr,\n")
   180  	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp2(c, d), expected,\n")
   181  	fmt.Fprintf(w, "\t\t\t\t\t\tactual,\n")
   182  	fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr)\n")
   183  	fmt.Fprintf(w, "\t\t\t\t}\n")
   184  	fmt.Fprintf(w, "\t\t\t}\n")
   185  	fmt.Fprintf(w, "\t\t}\n")
   186  	fmt.Fprintf(w, "\t}\n")
   187  
   188  	fmt.Fprintf(w, "}\n\n")
   189  }
   190  
   191  // writeAllTests generates tests for all supported integer types
   192  func writeAllTests(w *bytes.Buffer) {
   193  	for _, config := range typeConfigs {
   194  		writeTypeSpecificTest(w, config.typeName, config.baseValue)
   195  	}
   196  }
   197  
   198  func main() {
   199  	buffer := new(bytes.Buffer)
   200  
   201  	// Header for generated file
   202  	fmt.Fprintf(buffer, "// Code generated by conditionalCmpConstGen.go; DO NOT EDIT.\n\n")
   203  	fmt.Fprintf(buffer, "package test\n\n")
   204  	fmt.Fprintf(buffer, "import \"testing\"\n\n")
   205  
   206  	// Generate type definitions
   207  	writeTypeDefinitions(buffer)
   208  
   209  	// Generate test cases
   210  	generateTestCases(buffer)
   211  
   212  	// Generate specific tests for each integer type
   213  	writeAllTests(buffer)
   214  
   215  	// Format generated source code
   216  	rawSource := buffer.Bytes()
   217  	formattedSource, err := format.Source(rawSource)
   218  	if err != nil {
   219  		// Output raw source for debugging if formatting fails
   220  		fmt.Printf("%s\n", rawSource)
   221  		log.Fatal("error formatting generated code: ", err)
   222  	}
   223  
   224  	// Write to output file
   225  	outputPath := "../../conditionalCmpConst_test.go"
   226  	if err := os.WriteFile(outputPath, formattedSource, 0666); err != nil {
   227  		log.Fatal("failed to write output file: ", err)
   228  	}
   229  
   230  	log.Printf("Tests successfully generated to %s", outputPath)
   231  }
   232  

View as plain text