1
2
3
4
5
6
7
8
9
10
11 package main
12
import (
	"bytes"
	"fmt"
	"go/format"
	"log"
	"math/big"
	"os"
	"sort"
)
21
// Limits of the fixed-width Go integer types, kept as untyped constants so
// they can be offset (e.g. maxU64-1, minI8-1) before conversion to uint64 or
// int64.
const (
	// Largest unsigned values: 2^n - 1.
	maxU8  = 1<<8 - 1
	maxU16 = 1<<16 - 1
	maxU32 = 1<<32 - 1
	maxU64 = 1<<64 - 1

	// Largest signed values: 2^(n-1) - 1.
	maxI8  = 1<<7 - 1
	maxI16 = 1<<15 - 1
	maxI32 = 1<<31 - 1
	maxI64 = 1<<63 - 1

	// Smallest signed values: -2^(n-1), one below the negated maximum.
	minI8  = -maxI8 - 1
	minI16 = -maxI16 - 1
	minI32 = -maxI32 - 1
	minI64 = -maxI64 - 1
)
38
// cmp reports whether the relation "left op right" holds, where op is one of
// "<", "<=", ">", ">=", "==" or "!=". Any other op string yields false.
func cmp(left *big.Int, op string, right *big.Int) bool {
	sign := left.Cmp(right) // -1, 0 or +1
	switch op {
	case "<":
		return sign < 0
	case "<=":
		return sign <= 0
	case ">":
		return sign > 0
	case ">=":
		return sign >= 0
	case "==":
		return sign == 0
	case "!=":
		return sign != 0
	}
	return false
}
50
51 func inRange(typ string, val *big.Int) bool {
52 min, max := &big.Int{}, &big.Int{}
53 switch typ {
54 case "uint64":
55 max = max.SetUint64(maxU64)
56 case "uint32":
57 max = max.SetUint64(maxU32)
58 case "uint16":
59 max = max.SetUint64(maxU16)
60 case "uint8":
61 max = max.SetUint64(maxU8)
62 case "int64":
63 min = min.SetInt64(minI64)
64 max = max.SetInt64(maxI64)
65 case "int32":
66 min = min.SetInt64(minI32)
67 max = max.SetInt64(maxI32)
68 case "int16":
69 min = min.SetInt64(minI16)
70 max = max.SetInt64(maxI16)
71 case "int8":
72 min = min.SetInt64(minI8)
73 max = max.SetInt64(maxI8)
74 default:
75 panic("unexpected type")
76 }
77 return cmp(min, "<=", val) && cmp(val, "<=", max)
78 }
79
80 func getValues(typ string) []*big.Int {
81 Uint := func(v uint64) *big.Int { return big.NewInt(0).SetUint64(v) }
82 Int := func(v int64) *big.Int { return big.NewInt(0).SetInt64(v) }
83 values := []*big.Int{
84
85 Uint(maxU64),
86 Uint(maxU64 - 1),
87 Uint(maxI64 + 1),
88 Uint(maxI64),
89 Uint(maxI64 - 1),
90 Uint(maxU32 + 1),
91 Uint(maxU32),
92 Uint(maxU32 - 1),
93 Uint(maxI32 + 1),
94 Uint(maxI32),
95 Uint(maxI32 - 1),
96 Uint(maxU16 + 1),
97 Uint(maxU16),
98 Uint(maxU16 - 1),
99 Uint(maxI16 + 1),
100 Uint(maxI16),
101 Uint(maxI16 - 1),
102 Uint(maxU8 + 1),
103 Uint(maxU8),
104 Uint(maxU8 - 1),
105 Uint(maxI8 + 1),
106 Uint(maxI8),
107 Uint(maxI8 - 1),
108 Uint(0),
109 Int(minI8 + 1),
110 Int(minI8),
111 Int(minI8 - 1),
112 Int(minI16 + 1),
113 Int(minI16),
114 Int(minI16 - 1),
115 Int(minI32 + 1),
116 Int(minI32),
117 Int(minI32 - 1),
118 Int(minI64 + 1),
119 Int(minI64),
120
121
122 Uint(1),
123 Int(-1),
124 Uint(0xff << 56),
125 Uint(0xff << 32),
126 Uint(0xff << 24),
127 }
128 sort.Slice(values, func(i, j int) bool { return values[i].Cmp(values[j]) == -1 })
129 var ret []*big.Int
130 for _, val := range values {
131 if !inRange(typ, val) {
132 continue
133 }
134 ret = append(ret, val)
135 }
136 return ret
137 }
138
// sigString renders v in a form usable inside a Go identifier: its decimal
// digits, prefixed with "neg" instead of "-" when v is negative.
func sigString(v *big.Int) string {
	if v.Sign() >= 0 {
		return v.String()
	}
	return "neg" + new(big.Int).Neg(v).String()
}
147
148 func main() {
149 types := []string{
150 "uint64", "uint32", "uint16", "uint8",
151 "int64", "int32", "int16", "int8",
152 }
153
154 w := new(bytes.Buffer)
155 fmt.Fprintf(w, "// Code generated by gen/cmpConstGen.go. DO NOT EDIT.\n\n")
156 fmt.Fprintf(w, "package main;\n")
157 fmt.Fprintf(w, "import (\"testing\"; \"reflect\"; \"runtime\";)\n")
158 fmt.Fprintf(w, "// results show the expected result for the elements left of, equal to and right of the index.\n")
159 fmt.Fprintf(w, "type result struct{l, e, r bool}\n")
160 fmt.Fprintf(w, "var (\n")
161 fmt.Fprintf(w, " eq = result{l: false, e: true, r: false}\n")
162 fmt.Fprintf(w, " ne = result{l: true, e: false, r: true}\n")
163 fmt.Fprintf(w, " lt = result{l: true, e: false, r: false}\n")
164 fmt.Fprintf(w, " le = result{l: true, e: true, r: false}\n")
165 fmt.Fprintf(w, " gt = result{l: false, e: false, r: true}\n")
166 fmt.Fprintf(w, " ge = result{l: false, e: true, r: true}\n")
167 fmt.Fprintf(w, ")\n")
168
169 operators := []struct{ op, name string }{
170 {"<", "lt"},
171 {"<=", "le"},
172 {">", "gt"},
173 {">=", "ge"},
174 {"==", "eq"},
175 {"!=", "ne"},
176 }
177
178 for _, typ := range types {
179
180 fmt.Fprintf(w, "\n// %v tests\n", typ)
181 values := getValues(typ)
182 fmt.Fprintf(w, "var %v_vals = []%v{\n", typ, typ)
183 for _, val := range values {
184 fmt.Fprintf(w, "%v,\n", val.String())
185 }
186 fmt.Fprintf(w, "}\n")
187
188
189 for _, r := range values {
190
191 sig := sigString(r)
192 for _, op := range operators {
193
194 fmt.Fprintf(w, "func %v_%v_%v(x %v) bool { return x %v %v; }\n", op.name, sig, typ, typ, op.op, r.String())
195 }
196 }
197
198
199 fmt.Fprintf(w, "var %v_tests = []struct{\n", typ)
200 fmt.Fprintf(w, " idx int // index of the constant used\n")
201 fmt.Fprintf(w, " exp result // expected results\n")
202 fmt.Fprintf(w, " fn func(%v) bool\n", typ)
203 fmt.Fprintf(w, "}{\n")
204 for i, r := range values {
205 sig := sigString(r)
206 for _, op := range operators {
207 fmt.Fprintf(w, "{idx: %v,", i)
208 fmt.Fprintf(w, "exp: %v,", op.name)
209 fmt.Fprintf(w, "fn: %v_%v_%v},\n", op.name, sig, typ)
210 }
211 }
212 fmt.Fprintf(w, "}\n")
213 }
214
215
216 fmt.Fprintf(w, "// TestComparisonsConst tests results for comparison operations against constants.\n")
217 fmt.Fprintf(w, "func TestComparisonsConst(t *testing.T) {\n")
218 for _, typ := range types {
219 fmt.Fprintf(w, "for i, test := range %v_tests {\n", typ)
220 fmt.Fprintf(w, " for j, x := range %v_vals {\n", typ)
221 fmt.Fprintf(w, " want := test.exp.l\n")
222 fmt.Fprintf(w, " if j == test.idx {\nwant = test.exp.e\n}")
223 fmt.Fprintf(w, " else if j > test.idx {\nwant = test.exp.r\n}\n")
224 fmt.Fprintf(w, " if test.fn(x) != want {\n")
225 fmt.Fprintf(w, " fn := runtime.FuncForPC(reflect.ValueOf(test.fn).Pointer()).Name()\n")
226 fmt.Fprintf(w, " t.Errorf(\"test failed: %%v(%%v) != %%v [type=%v i=%%v j=%%v idx=%%v]\", fn, x, want, i, j, test.idx)\n", typ)
227 fmt.Fprintf(w, " }\n")
228 fmt.Fprintf(w, " }\n")
229 fmt.Fprintf(w, "}\n")
230 }
231 fmt.Fprintf(w, "}\n")
232
233
234 b := w.Bytes()
235 src, err := format.Source(b)
236 if err != nil {
237 fmt.Printf("%s\n", b)
238 panic(err)
239 }
240
241
242 err = os.WriteFile("../cmpConst_test.go", src, 0666)
243 if err != nil {
244 log.Fatalf("can't write output: %v\n", err)
245 }
246 }
247
View as plain text