1
2
3
4
5
6
7
8
9
10
11 package main
12
13 import (
14 "bytes"
15 "fmt"
16 "go/format"
17 "log"
18 "os"
19 "strings"
20 )
21
22
// writeIntegerConstraint emits the IntegerConstraint type-set interface used as
// the type-parameter constraint in the generated code. The union lists int,
// uint and the sized int8..int64 / uint8..uint64 types (uintptr is not included).
func writeIntegerConstraint(w *bytes.Buffer) {
	const decl = "type IntegerConstraint interface {\n" +
		"\tint | uint | int8 | uint8 | int16 | " +
		"uint16 | int32 | uint32 | int64 | uint64\n" +
		"}\n\n"
	fmt.Fprint(w, decl)
}
29
30
// writeTestCaseStruct emits the generic TestCase struct. Each generated test
// case carries:
//   - the two comparators cmp1 and cmp2,
//   - the boolean combine function,
//   - the combined targetFunc under test,
//   - the source-text form of each expression, for failure messages.
func writeTestCaseStruct(w *bytes.Buffer) {
	lines := []string{
		"type TestCase[T IntegerConstraint] struct {",
		"\tcmp1, cmp2 func(a, b T) bool",
		"\tcombine func(x, y bool) bool",
		"\ttargetFunc func(a, b, c, d T) bool",
		"\tcmp1Expr, cmp2Expr, logicalExpr string // String representations for debugging",
		"}",
	}
	for _, line := range lines {
		fmt.Fprintln(w, line)
	}
	// Blank line separating this declaration from the next one.
	fmt.Fprintln(w)
}
39
40
// writeBoundaryValuesStruct emits the generic BoundaryValues struct, which
// holds a base value and three variants of it.
//
// NOTE(review): none of the code emitted by this generator references
// BoundaryValues; confirm whether anything else in package test still needs it.
func writeBoundaryValuesStruct(w *bytes.Buffer) {
	// All operands are strings, so Fprint inserts no separators between them.
	fmt.Fprint(w,
		"type BoundaryValues[T IntegerConstraint] struct {\n",
		"\tbase T\n",
		"\tvariants [3]T\n",
		"}\n\n",
	)
}
47
48
49 func writeTypeDefinitions(w *bytes.Buffer) {
50 writeIntegerConstraint(w)
51 writeTestCaseStruct(w)
52 writeBoundaryValuesStruct(w)
53 }
54
55
// comparisonOperators holds printf-style templates for the six integer
// comparison operators. Each template takes two %s operands (left, right).
var comparisonOperators = []string{
	"%s == %s",
	"%s <= %s",
	"%s < %s",
	"%s != %s",
	"%s >= %s",
	"%s > %s",
}
60
61
// logicalOperators holds printf-style templates combining two boolean operands
// with && or ||. Each template appears with every negation pattern of its
// operands. Each template takes two %s operands, which are parenthesized
// in the output.
var logicalOperators = []string{
	// conjunctions
	"(%s) && (%s)",
	"(%s) && !(%s)",
	"!(%s) && (%s)",
	"!(%s) && !(%s)",
	// disjunctions
	"(%s) || (%s)",
	"(%s) || !(%s)",
	"!(%s) || (%s)",
	"!(%s) || !(%s)",
}
66
67
// writeComparator emits one composite-literal field of the form
// `<fieldName>: func(a, b T) bool { return <expr> },`. Here <expr> is
// operatorFormat instantiated with operands a and b.
func writeComparator(w *bytes.Buffer, fieldName, operatorFormat string) {
	body := fmt.Sprintf(operatorFormat, "a", "b")
	w.WriteString("\t\t\t" + fieldName + ": func(a, b T) bool { return " + body + " },\n")
}
72
73
// writeLogicalCombiner emits the `combine:` field of a TestCase literal. The
// field is a func(x, y bool) bool whose body is logicalOperator instantiated
// with operands x and y.
func writeLogicalCombiner(w *bytes.Buffer, logicalOperator string) {
	fmt.Fprintf(w, "\t\t\tcombine: func(x, y bool) bool { return %s },\n",
		fmt.Sprintf(logicalOperator, "x", "y"))
}
78
79
// writeTargetFunction emits the `targetFunc:` field of a TestCase literal.
// The emitted function evaluates logicalOp(cmp1(a, b), cmp2(c, d)) inline as
// an if condition and returns true or false from explicit branches.
//
// NOTE(review): the explicit if/return-true/return-false shape is presumably
// deliberate, so the compiler under test sees a conditional branch rather
// than a plain boolean return. Confirm before simplifying.
func writeTargetFunction(w *bytes.Buffer, cmp1, cmp2, logicalOp string) {
	cond := fmt.Sprintf(logicalOp,
		fmt.Sprintf(cmp1, "a", "b"),
		fmt.Sprintf(cmp2, "c", "d"))

	const tmpl = "\t\t\ttargetFunc: func(a, b, c, d T) bool {\n" +
		"\t\t\t\tif %s {\n" +
		"\t\t\t\t\treturn true\n" +
		"\t\t\t\t}\n" +
		"\t\t\t\treturn false\n" +
		"\t\t\t},\n"
	fmt.Fprintf(w, tmpl, cond)
}
92
93
94 func writeTestCase(w *bytes.Buffer, cmp1, cmp2, logicalOp string) {
95 fmt.Fprintf(w, "\t\t{\n")
96 writeComparator(w, "cmp1", cmp1)
97 writeComparator(w, "cmp2", cmp2)
98 writeLogicalCombiner(w, logicalOp)
99 writeTargetFunction(w, cmp1, cmp2, logicalOp)
100
101
102 cmp1Expr := fmt.Sprintf(cmp1, "a", "b")
103 cmp2Expr := fmt.Sprintf(cmp2, "c", "d")
104 logicalExpr := fmt.Sprintf(logicalOp, cmp1Expr, cmp2Expr)
105
106 fmt.Fprintf(w, "\t\t\tcmp1Expr: %q,\n", cmp1Expr)
107 fmt.Fprintf(w, "\t\t\tcmp2Expr: %q,\n", cmp2Expr)
108 fmt.Fprintf(w, "\t\t\tlogicalExpr: %q,\n", logicalExpr)
109
110 fmt.Fprintf(w, "\t\t},\n")
111 }
112
113
114 func generateTestCases(w *bytes.Buffer) {
115 fmt.Fprintf(w, "func generateTestCases[T IntegerConstraint]() []TestCase[T] {\n")
116 fmt.Fprintf(w, "\treturn []TestCase[T]{\n")
117
118 for _, cmp1 := range comparisonOperators {
119 for _, cmp2 := range comparisonOperators {
120 for _, logicalOp := range logicalOperators {
121 writeTestCase(w, cmp1, cmp2, logicalOp)
122 }
123 }
124 }
125
126 fmt.Fprintf(w, "\t}\n")
127 fmt.Fprintf(w, "}\n\n")
128 }
129
130
// TypeConfig pairs a Go integer type name with the constant expression used
// as the base value in that type's generated test.
type TypeConfig struct {
	typeName  string // Go type name, e.g. "int8"
	baseValue string // constant expression converted to typeName, e.g. "1 << 6"
}
134
135
136 var typeConfigs = []TypeConfig{
137 {typeName: "int8", baseValue: "1 << 6"},
138 {typeName: "uint8", baseValue: "1 << 6"},
139 {typeName: "int16", baseValue: "1 << 14"},
140 {typeName: "uint16", baseValue: "1 << 14"},
141 {typeName: "int32", baseValue: "1 << 30"},
142 {typeName: "uint32", baseValue: "1 << 30"},
143 {typeName: "int", baseValue: "1 << 30"},
144 {typeName: "uint", baseValue: "1 << 30"},
145 {typeName: "int64", baseValue: "1 << 62"},
146 {typeName: "uint64", baseValue: "1 << 62"},
147 }
148
149
// writeTypeSpecificTest emits Test<TypeName>ConditionalCmpConst for a single
// integer type. The generated test:
//   - instantiates generateTestCases[typeName],
//   - fixes a and c at base := typeName(baseValue),
//   - runs every test case over each (b, d) pair drawn from
//     {base-1, base, base+1},
//   - reports a detailed error when targetFunc(a, b, c, d) differs from
//     combine(cmp1(a, b), cmp2(c, d)).
//
// typeName must be a Go integer type name. baseValue must be a constant
// expression representable in that type, because the conversion
// typeName(baseValue) is emitted verbatim. base±1 are computed at run time
// in the generated code.
func writeTypeSpecificTest(w *bytes.Buffer, typeName, baseValue string) {
	// strings.Title is deprecated: its word-boundary rule does not handle
	// Unicode punctuation properly. The type names used here are lower-case
	// ASCII identifiers, so upper-casing the first byte is enough to form an
	// exported test name.
	typeTitle := typeName
	if typeName != "" {
		typeTitle = strings.ToUpper(typeName[:1]) + typeName[1:]
	}

	fmt.Fprintf(w, "func Test%sConditionalCmpConst(t *testing.T) {\n", typeTitle)

	fmt.Fprintf(w, "\ttestCases := generateTestCases[%s]()\n", typeName)
	fmt.Fprintf(w, "\tbase := %s(%s)\n", typeName, baseValue)
	fmt.Fprintf(w, "\tvalues := [3]%s{base - 1, base, base + 1}\n\n", typeName)

	// a and c stay at base; b and d sweep the three neighbouring values.
	fmt.Fprintf(w, "\tfor _, tc := range testCases {\n")
	fmt.Fprintf(w, "\t\ta, c := base, base\n")
	fmt.Fprintf(w, "\t\tfor _, b := range values {\n")
	fmt.Fprintf(w, "\t\t\tfor _, d := range values {\n")
	fmt.Fprintf(w, "\t\t\t\texpected := tc.combine(tc.cmp1(a, b), tc.cmp2(c, d))\n")
	fmt.Fprintf(w, "\t\t\t\tactual := tc.targetFunc(a, b, c, d)\n")
	fmt.Fprintf(w, "\t\t\t\tif actual != expected {\n")
	// Each %% below is written as a single % in the generated t.Errorf format string.
	fmt.Fprintf(w, "\t\t\t\t\tt.Errorf(\"conditional comparison failed:\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" type: %%T\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" condition: %%s\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" values: a=%%v, b=%%v, c=%%v, d=%%v\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" cmp1(a,b)=%%v (%%s)\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" cmp2(c,d)=%%v (%%s)\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" expected: combine(%%v, %%v)=%%v\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" actual: %%v\\n\"+\n")
	fmt.Fprintf(w, "\t\t\t\t\t\t\" logical expression: %%s\",\n")
	// Arguments, in the same order as the 15 verbs above.
	fmt.Fprintf(w, "\t\t\t\t\t\ta,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ta, b, c, d,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp1Expr,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp2(c, d), tc.cmp2Expr,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ttc.cmp1(a, b), tc.cmp2(c, d), expected,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\tactual,\n")
	fmt.Fprintf(w, "\t\t\t\t\t\ttc.logicalExpr)\n")
	fmt.Fprintf(w, "\t\t\t\t}\n")
	fmt.Fprintf(w, "\t\t\t}\n")
	fmt.Fprintf(w, "\t\t}\n")
	fmt.Fprintf(w, "\t}\n")

	fmt.Fprintf(w, "}\n\n")
}
190
191
192 func writeAllTests(w *bytes.Buffer) {
193 for _, config := range typeConfigs {
194 writeTypeSpecificTest(w, config.typeName, config.baseValue)
195 }
196 }
197
198 func main() {
199 buffer := new(bytes.Buffer)
200
201
202 fmt.Fprintf(buffer, "// Code generated by conditionalCmpConstGen.go; DO NOT EDIT.\n\n")
203 fmt.Fprintf(buffer, "package test\n\n")
204 fmt.Fprintf(buffer, "import \"testing\"\n\n")
205
206
207 writeTypeDefinitions(buffer)
208
209
210 generateTestCases(buffer)
211
212
213 writeAllTests(buffer)
214
215
216 rawSource := buffer.Bytes()
217 formattedSource, err := format.Source(rawSource)
218 if err != nil {
219
220 fmt.Printf("%s\n", rawSource)
221 log.Fatal("error formatting generated code: ", err)
222 }
223
224
225 outputPath := "../../conditionalCmpConst_test.go"
226 if err := os.WriteFile(outputPath, formattedSource, 0666); err != nil {
227 log.Fatal("failed to write output file: ", err)
228 }
229
230 log.Printf("Tests successfully generated to %s", outputPath)
231 }
232
View as plain text