1
2
3
4
5 package fuzz
6
7 import (
8 "bytes"
9 "fmt"
10 "go/ast"
11 "go/parser"
12 "go/token"
13 "math"
14 "strconv"
15 "strings"
16 "unicode/utf8"
17 )
18
19
// encVersion1 is the header line written at the top of every corpus file by
// marshalCorpusFile and required as the first line by unmarshalCorpusFile.
var encVersion1 = "go test fuzz v1"

// marshalCorpusFile encodes vals into the textual corpus file format: the
// encVersion1 header followed by one line per value, each written as a Go
// conversion-style expression such as int(1), string("x") or []byte("x").
//
// It panics if no values are given or if a value has an unsupported type.
func marshalCorpusFile(vals ...any) []byte {
	if len(vals) == 0 {
		panic("must have at least one value to marshal")
	}

	var out bytes.Buffer
	out.WriteString(encVersion1)
	out.WriteByte('\n')

	for _, v := range vals {
		switch x := v.(type) {
		case int, int8, int16, int64, uint, uint16, uint32, uint64, bool:
			fmt.Fprintf(&out, "%T(%v)\n", x, x)

		case float32:
			// A NaN whose bit pattern differs from float32(math.NaN()) is
			// written as raw bits so the exact representation survives a
			// round trip; every other value uses the default formatting.
			bits := math.Float32bits(x)
			if math.IsNaN(float64(x)) && bits != math.Float32bits(float32(math.NaN())) {
				fmt.Fprintf(&out, "math.Float32frombits(0x%x)\n", bits)
				continue
			}
			fmt.Fprintf(&out, "%T(%v)\n", x, x)

		case float64:
			// Same scheme as float32, using the 64-bit pattern.
			bits := math.Float64bits(x)
			if math.IsNaN(x) && bits != math.Float64bits(math.NaN()) {
				fmt.Fprintf(&out, "math.Float64frombits(0x%x)\n", bits)
				continue
			}
			fmt.Fprintf(&out, "%T(%v)\n", x, x)

		case string:
			fmt.Fprintf(&out, "string(%q)\n", x)

		case rune: // int32
			// Values that are not valid runes have no faithful quoted
			// character form, so they are written as plain int32 numbers.
			if !utf8.ValidRune(x) {
				fmt.Fprintf(&out, "int32(%v)\n", x)
				continue
			}
			fmt.Fprintf(&out, "rune(%q)\n", x)

		case byte: // uint8
			// Bytes are always written in quoted character form.
			fmt.Fprintf(&out, "byte(%q)\n", x)

		case []byte: // []uint8
			fmt.Fprintf(&out, "[]byte(%q)\n", x)

		default:
			panic(fmt.Sprintf("unsupported type: %T", x))
		}
	}
	return out.Bytes()
}
100
101
102 func unmarshalCorpusFile(b []byte) ([]any, error) {
103 if len(b) == 0 {
104 return nil, fmt.Errorf("cannot unmarshal empty string")
105 }
106 lines := bytes.Split(b, []byte("\n"))
107 if len(lines) < 2 {
108 return nil, fmt.Errorf("must include version and at least one value")
109 }
110 version := strings.TrimSuffix(string(lines[0]), "\r")
111 if version != encVersion1 {
112 return nil, fmt.Errorf("unknown encoding version: %s", version)
113 }
114 var vals []any
115 for _, line := range lines[1:] {
116 line = bytes.TrimSpace(line)
117 if len(line) == 0 {
118 continue
119 }
120 v, err := parseCorpusValue(line)
121 if err != nil {
122 return nil, fmt.Errorf("malformed line %q: %v", line, err)
123 }
124 vals = append(vals, v)
125 }
126 return vals, nil
127 }
128
129 func parseCorpusValue(line []byte) (any, error) {
130 fs := token.NewFileSet()
131 expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
132 if err != nil {
133 return nil, err
134 }
135 call, ok := expr.(*ast.CallExpr)
136 if !ok {
137 return nil, fmt.Errorf("expected call expression")
138 }
139 if len(call.Args) != 1 {
140 return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
141 }
142 arg := call.Args[0]
143
144 if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
145 if arrayType.Len != nil {
146 return nil, fmt.Errorf("expected []byte or primitive type")
147 }
148 elt, ok := arrayType.Elt.(*ast.Ident)
149 if !ok || elt.Name != "byte" {
150 return nil, fmt.Errorf("expected []byte")
151 }
152 lit, ok := arg.(*ast.BasicLit)
153 if !ok || lit.Kind != token.STRING {
154 return nil, fmt.Errorf("string literal required for type []byte")
155 }
156 s, err := strconv.Unquote(lit.Value)
157 if err != nil {
158 return nil, err
159 }
160 return []byte(s), nil
161 }
162
163 var idType *ast.Ident
164 if selector, ok := call.Fun.(*ast.SelectorExpr); ok {
165 xIdent, ok := selector.X.(*ast.Ident)
166 if !ok || xIdent.Name != "math" {
167 return nil, fmt.Errorf("invalid selector type")
168 }
169 switch selector.Sel.Name {
170 case "Float64frombits":
171 idType = &ast.Ident{Name: "float64-bits"}
172 case "Float32frombits":
173 idType = &ast.Ident{Name: "float32-bits"}
174 default:
175 return nil, fmt.Errorf("invalid selector type")
176 }
177 } else {
178 idType, ok = call.Fun.(*ast.Ident)
179 if !ok {
180 return nil, fmt.Errorf("expected []byte or primitive type")
181 }
182 if idType.Name == "bool" {
183 id, ok := arg.(*ast.Ident)
184 if !ok {
185 return nil, fmt.Errorf("malformed bool")
186 }
187 if id.Name == "true" {
188 return true, nil
189 } else if id.Name == "false" {
190 return false, nil
191 } else {
192 return nil, fmt.Errorf("true or false required for type bool")
193 }
194 }
195 }
196
197 var (
198 val string
199 kind token.Token
200 )
201 if op, ok := arg.(*ast.UnaryExpr); ok {
202 switch lit := op.X.(type) {
203 case *ast.BasicLit:
204 if op.Op != token.SUB {
205 return nil, fmt.Errorf("unsupported operation on int/float: %v", op.Op)
206 }
207
208 val = op.Op.String() + lit.Value
209 kind = lit.Kind
210 case *ast.Ident:
211 if lit.Name != "Inf" {
212 return nil, fmt.Errorf("expected operation on int or float type")
213 }
214 if op.Op == token.SUB {
215 val = "-Inf"
216 } else {
217 val = "+Inf"
218 }
219 kind = token.FLOAT
220 default:
221 return nil, fmt.Errorf("expected operation on int or float type")
222 }
223 } else {
224 switch lit := arg.(type) {
225 case *ast.BasicLit:
226 val, kind = lit.Value, lit.Kind
227 case *ast.Ident:
228 if lit.Name != "NaN" {
229 return nil, fmt.Errorf("literal value required for primitive type")
230 }
231 val, kind = "NaN", token.FLOAT
232 default:
233 return nil, fmt.Errorf("literal value required for primitive type")
234 }
235 }
236
237 switch typ := idType.Name; typ {
238 case "string":
239 if kind != token.STRING {
240 return nil, fmt.Errorf("string literal value required for type string")
241 }
242 return strconv.Unquote(val)
243 case "byte", "rune":
244 if kind == token.INT {
245 switch typ {
246 case "rune":
247 return parseInt(val, typ)
248 case "byte":
249 return parseUint(val, typ)
250 }
251 }
252 if kind != token.CHAR {
253 return nil, fmt.Errorf("character literal required for byte/rune types")
254 }
255 n := len(val)
256 if n < 2 {
257 return nil, fmt.Errorf("malformed character literal, missing single quotes")
258 }
259 code, _, _, err := strconv.UnquoteChar(val[1:n-1], '\'')
260 if err != nil {
261 return nil, err
262 }
263 if typ == "rune" {
264 return code, nil
265 }
266 if code >= 256 {
267 return nil, fmt.Errorf("can only encode single byte to a byte type")
268 }
269 return byte(code), nil
270 case "int", "int8", "int16", "int32", "int64":
271 if kind != token.INT {
272 return nil, fmt.Errorf("integer literal required for int types")
273 }
274 return parseInt(val, typ)
275 case "uint", "uint8", "uint16", "uint32", "uint64":
276 if kind != token.INT {
277 return nil, fmt.Errorf("integer literal required for uint types")
278 }
279 return parseUint(val, typ)
280 case "float32":
281 if kind != token.FLOAT && kind != token.INT {
282 return nil, fmt.Errorf("float or integer literal required for float32 type")
283 }
284 v, err := strconv.ParseFloat(val, 32)
285 return float32(v), err
286 case "float64":
287 if kind != token.FLOAT && kind != token.INT {
288 return nil, fmt.Errorf("float or integer literal required for float64 type")
289 }
290 return strconv.ParseFloat(val, 64)
291 case "float32-bits":
292 if kind != token.INT {
293 return nil, fmt.Errorf("integer literal required for math.Float32frombits type")
294 }
295 bits, err := parseUint(val, "uint32")
296 if err != nil {
297 return nil, err
298 }
299 return math.Float32frombits(bits.(uint32)), nil
300 case "float64-bits":
301 if kind != token.FLOAT && kind != token.INT {
302 return nil, fmt.Errorf("integer literal required for math.Float64frombits type")
303 }
304 bits, err := parseUint(val, "uint64")
305 if err != nil {
306 return nil, err
307 }
308 return math.Float64frombits(bits.(uint64)), nil
309 default:
310 return nil, fmt.Errorf("expected []byte or primitive type")
311 }
312 }
313
314
// parseInt parses val as a signed integer literal (base inferred from its
// prefix) and returns it as the Go type named by typ: int, int8, int16,
// int32 (also selected by "rune") or int64.
//
// The value and the error from strconv.ParseInt are returned together, so the
// value is only meaningful when the error is nil. Any other typ panics.
func parseInt(val, typ string) (any, error) {
	// "int" is always parsed with 64 bits and then converted to int.
	var bits int
	switch typ {
	case "int", "int64":
		bits = 64
	case "int8":
		bits = 8
	case "int16":
		bits = 16
	case "int32", "rune":
		bits = 32
	default:
		panic("unreachable")
	}

	n, err := strconv.ParseInt(val, 0, bits)
	switch typ {
	case "int":
		return int(n), err
	case "int8":
		return int8(n), err
	case "int16":
		return int16(n), err
	case "int64":
		return n, err
	}
	// Remaining names are "int32" and "rune".
	return int32(n), err
}
340
341
// parseUint parses val as an unsigned integer literal (base inferred from its
// prefix) and returns it as the Go type named by typ: uint, uint8 (also
// selected by "byte"), uint16, uint32 or uint64.
//
// The value and the error from strconv.ParseUint are returned together, so
// the value is only meaningful when the error is nil. Any other typ panics.
func parseUint(val, typ string) (any, error) {
	// "uint" is always parsed with 64 bits and then converted to uint.
	var bits int
	switch typ {
	case "uint", "uint64":
		bits = 64
	case "uint8", "byte":
		bits = 8
	case "uint16":
		bits = 16
	case "uint32":
		bits = 32
	default:
		panic("unreachable")
	}

	n, err := strconv.ParseUint(val, 0, bits)
	switch typ {
	case "uint":
		return uint(n), err
	case "uint16":
		return uint16(n), err
	case "uint32":
		return uint32(n), err
	case "uint64":
		return n, err
	}
	// Remaining names are "uint8" and "byte".
	return uint8(n), err
}
362
View as plain text