// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This program generates tests to verify that conditional comparisons // with constants are properly optimized by the compiler through constant folding. // The generated test should be compiled with a known working version of Go. // Run with `go run conditionalCmpConstGen.go` to generate a file called // conditionalCmpConst_test.go in the grandparent directory. package main import ( "bytes" "fmt" "go/format" "log" "os" "strings" ) // IntegerConstraint defines a type constraint for all integer types func writeIntegerConstraint(w *bytes.Buffer) { fmt.Fprintf(w, "type IntegerConstraint interface {\n") fmt.Fprintf(w, "\tint | uint | int8 | uint8 | int16 | ") fmt.Fprintf(w, "uint16 | int32 | uint32 | int64 | uint64\n") fmt.Fprintf(w, "}\n\n") } // TestCase describes a parameterized test case with comparison and logical operations func writeTestCaseStruct(w *bytes.Buffer) { fmt.Fprintf(w, "type TestCase[T IntegerConstraint] struct {\n") fmt.Fprintf(w, "\tcmp1, cmp2 func(a, b T) bool\n") fmt.Fprintf(w, "\tcombine func(x, y bool) bool\n") fmt.Fprintf(w, "\ttargetFunc func(a, b, c, d T) bool\n") fmt.Fprintf(w, "\tcmp1Expr, cmp2Expr, logicalExpr string // String representations for debugging\n") fmt.Fprintf(w, "}\n\n") } // BoundaryValues contains base value and its variations for edge case testing func writeBoundaryValuesStruct(w *bytes.Buffer) { fmt.Fprintf(w, "type BoundaryValues[T IntegerConstraint] struct {\n") fmt.Fprintf(w, "\tbase T\n") fmt.Fprintf(w, "\tvariants [3]T\n") fmt.Fprintf(w, "}\n\n") } // writeTypeDefinitions generates all necessary type declarations func writeTypeDefinitions(w *bytes.Buffer) { writeIntegerConstraint(w) writeTestCaseStruct(w) writeBoundaryValuesStruct(w) } // comparisonOperators contains format strings for comparison operators var comparisonOperators = []string{ "%s == %s", "%s <= %s", "%s < %s", "%s != %s", "%s >= %s", "%s > %s", } // logicalOperators contains format strings for logical combination of boolean expressions var logicalOperators = []string{ "(%s) && (%s)", "(%s) && !(%s)", "!(%s) && (%s)", "!(%s) && !(%s)", "(%s) || (%s)", "(%s) || !(%s)", "!(%s) || (%s)", "!(%s) || !(%s)", } // writeComparator generates a comparator function based on the comparison operator func writeComparator(w *bytes.Buffer, fieldName, operatorFormat string) { expression := fmt.Sprintf(operatorFormat, "a", "b") fmt.Fprintf(w, "\t\t\t%s: func(a, b T) bool { return %s },\n", fieldName, expression) } // writeLogicalCombiner generates a function to combine two boolean values func writeLogicalCombiner(w *bytes.Buffer, logicalOperator string) { expression := fmt.Sprintf(logicalOperator, "x", "y") fmt.Fprintf(w, "\t\t\tcombine: func(x, y bool) bool { return %s },\n", expression) } // writeTargetFunction generates the target function with conditional expression func writeTargetFunction(w *bytes.Buffer, cmp1, cmp2, logicalOp string) { leftExpr := fmt.Sprintf(cmp1, "a", "b") rightExpr := fmt.Sprintf(cmp2, "c", "d") condition := fmt.Sprintf(logicalOp, leftExpr, rightExpr) fmt.Fprintf(w, "\t\t\ttargetFunc: func(a, b, c, d T) bool {\n") fmt.Fprintf(w, "\t\t\t\tif %s {\n", condition) fmt.Fprintf(w, "\t\t\t\t\treturn true\n") fmt.Fprintf(w, "\t\t\t\t}\n") fmt.Fprintf(w, "\t\t\t\treturn false\n") fmt.Fprintf(w, "\t\t\t},\n") } // writeTestCase creates a single test case with given comparison and logical operators func writeTestCase(w *bytes.Buffer, cmp1, cmp2, logicalOp string) { fmt.Fprintf(w, "\t\t{\n") writeComparator(w, "cmp1", cmp1) writeComparator(w, "cmp2", cmp2) writeLogicalCombiner(w, logicalOp) writeTargetFunction(w, cmp1, cmp2, logicalOp) // Store string representations for debugging cmp1Expr := fmt.Sprintf(cmp1, "a", "b") cmp2Expr := fmt.Sprintf(cmp2, "c", "d") logicalExpr := fmt.Sprintf(logicalOp, cmp1Expr, cmp2Expr) fmt.Fprintf(w, "\t\t\tcmp1Expr: %q,\n", cmp1Expr) fmt.Fprintf(w, "\t\t\tcmp2Expr: %q,\n", cmp2Expr) fmt.Fprintf(w, "\t\t\tlogicalExpr: %q,\n", logicalExpr) fmt.Fprintf(w, "\t\t},\n") } // generateTestCases creates a slice of all possible test cases func generateTestCases(w *bytes.Buffer) { fmt.Fprintf(w, "func generateTestCases[T IntegerConstraint]() []TestCase[T] {\n") fmt.Fprintf(w, "\treturn []TestCase[T]{\n") for _, cmp1 := range comparisonOperators { for _, cmp2 := range comparisonOperators { for _, logicalOp := range logicalOperators { writeTestCase(w, cmp1, cmp2, logicalOp) } } } fmt.Fprintf(w, "\t}\n") fmt.Fprintf(w, "}\n\n") } // TypeConfig defines a type and its test base value type TypeConfig struct { typeName, baseValue string } // typeConfigs contains all integer types to test with their base values var typeConfigs = []TypeConfig{ {typeName: "int8", baseValue: "1 << 6"}, {typeName: "uint8", baseValue: "1 << 6"}, {typeName: "int16", baseValue: "1 << 14"}, {typeName: "uint16", baseValue: "1 << 14"}, {typeName: "int32", baseValue: "1 << 30"}, {typeName: "uint32", baseValue: "1 << 30"}, {typeName: "int", baseValue: "1 << 30"}, {typeName: "uint", baseValue: "1 << 30"}, {typeName: "int64", baseValue: "1 << 62"}, {typeName: "uint64", baseValue: "1 << 62"}, } // writeTypeSpecificTest generates test for a specific integer type func writeTypeSpecificTest(w *bytes.Buffer, typeName, baseValue string) { typeTitle := strings.Title(typeName) fmt.Fprintf(w, "func Test%sConditionalCmpConst(t *testing.T) {\n", typeTitle) fmt.Fprintf(w, "\ttestCases := generateTestCases[%s]()\n", typeName) fmt.Fprintf(w, "\tbase := %s(%s)\n", typeName, baseValue) fmt.Fprintf(w, "\tvalues := [3]%s{base - 1, base, base + 1}\n\n", typeName) fmt.Fprintf(w, "\tfor _, tc := range testCases {\n") fmt.Fprintf(w, "\t\ta, c := base, base\n") fmt.Fprintf(w, "\t\tfor _, b := range values {\n") fmt.Fprintf(w, "\t\t\tfor _, d := range values {\n") fmt.Fprintf(w, "\t\t\t\texpected := tc.combine(tc.cmp1(a, b), tc.cmp2(c, d))\n") fmt.Fprintf(w, "\t\t\t\tactual := tc.targetFunc(a, b, c, d)\n") fmt.Fprintf(w, "\t\t\t\tif actual != expected {\n") fmt.Fprintf(w, "\t\t\t\t\tt.Errorf(\"conditional comparison failed:\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" type: %%T\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" condition: %%s\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" values: a=%%v, b=%%v, c=%%v, d=%%v\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" cmp1(a,b)=%%v (%%s)\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" cmp2(c,d)=%%v (%%s)\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" expected: combine(%%v, %%v)=%%v\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" actual: %%v\\n\"+\n") fmt.Fprintf(w, "\t\t\t\t\t\t\" logical expression: %%s\",\n") fmt.Fprintf(w, "\t\t\t\t\t\ta,\n") fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr,\n") fmt.Fprintf(w, "\t\t\t\t\t\ta, b, c, d,\n") fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp1Expr,\n") fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp2(c, d), tc.cmp2Expr,\n") fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp2(c, d), expected,\n") fmt.Fprintf(w, "\t\t\t\t\t\tactual,\n") fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr)\n") fmt.Fprintf(w, "\t\t\t\t}\n") fmt.Fprintf(w, "\t\t\t}\n") fmt.Fprintf(w, "\t\t}\n") fmt.Fprintf(w, "\t}\n") fmt.Fprintf(w, "}\n\n") } // writeAllTests generates tests for all supported integer types func writeAllTests(w *bytes.Buffer) { for _, config := range typeConfigs { writeTypeSpecificTest(w, config.typeName, config.baseValue) } } func main() { buffer := new(bytes.Buffer) // Header for generated file fmt.Fprintf(buffer, "// Code generated by conditionalCmpConstGen.go; DO NOT EDIT.\n\n") fmt.Fprintf(buffer, "package test\n\n") fmt.Fprintf(buffer, "import \"testing\"\n\n") // Generate type definitions writeTypeDefinitions(buffer) // Generate test cases generateTestCases(buffer) // Generate specific tests for each integer type writeAllTests(buffer) // Format generated source code rawSource := buffer.Bytes() formattedSource, err := format.Source(rawSource) if err != nil { // Output raw source for debugging if formatting fails fmt.Printf("%s\n", rawSource) log.Fatal("error formatting generated code: ", err) } // Write to output file outputPath := "../../conditionalCmpConst_test.go" if err := os.WriteFile(outputPath, formattedSource, 0666); err != nil { log.Fatal("failed to write output file: ", err) } log.Printf("Tests successfully generated to %s", outputPath) }