1
2
3
4
5 package sumdb
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "strings"
12 "sync"
13 "sync/atomic"
14
15 "golang.org/x/mod/module"
16 "golang.org/x/mod/sumdb/note"
17 "golang.org/x/mod/sumdb/tlog"
18 )
19
20
21
22
// ClientOps is the set of external operations a Client relies on:
// fetching from the server, persistent configuration, a local cache,
// and reporting. A Client is created from a ClientOps by NewClient.
type ClientOps interface {
	// ReadRemote fetches the content stored at the given server path.
	// The Client asks for "/lookup/<module>@<version>" paths when
	// looking up records and for "/" + tile.Path() when reading tiles.
	ReadRemote(path string) ([]byte, error)

	// ReadConfig reads the named configuration file.
	// The Client reads two files: "key", whose trimmed content is
	// parsed as a note verifier key, and "<name>/latest", which holds
	// the most recent signed tree note. Empty data for the latter is
	// treated as "no tree recorded yet".
	ReadConfig(file string) ([]byte, error)

	// WriteConfig replaces the named configuration file with new.
	// The Client passes as old the content it most recently read from
	// that file. Returning ErrWriteConflict makes the Client re-read
	// the file and try again; any other result is returned as is.
	// NOTE(review): presumably a compare-and-swap on old; confirm
	// against the implementations.
	WriteConfig(file string, old []byte, new []byte) error

	// ReadCache reads the named cache entry. The Client treats any
	// error as a cache miss and falls back to other sources.
	ReadCache(file string) ([]byte, error)

	// WriteCache stores data as the named cache entry. It reports no
	// error, so failures cannot be observed by the Client.
	WriteCache(file string, data []byte)

	// Log receives an informational message.
	// NOTE(review): not called by the code visible in this file.
	Log(msg string)

	// SecurityError receives a report describing detected server
	// misbehavior. The Client calls it right before returning
	// ErrSecurity.
	SecurityError(msg string)
}
70
71
// Sentinel errors shared between a Client and its ClientOps.
var (
	// ErrWriteConflict is the result a ClientOps.WriteConfig returns to
	// make the Client re-read the configuration file and retry the write.
	ErrWriteConflict = errors.New("write conflict")

	// ErrSecurity is returned after a Client has reported server
	// misbehavior through ClientOps.SecurityError.
	ErrSecurity = errors.New("security error: misbehaving server")
)
76
77
78
79 type Client struct {
80 ops ClientOps
81
82 didLookup uint32
83
84
85 initOnce sync.Once
86 initErr error
87 name string
88 verifiers note.Verifiers
89 tileReader tileReader
90 tileHeight int
91 nosumdb string
92
93 record parCache
94 tileCache parCache
95
96 latestMu sync.Mutex
97 latest tlog.Tree
98 latestMsg []byte
99
100 tileSavedMu sync.Mutex
101 tileSaved map[tlog.Tile]bool
102 }
103
104
105 func NewClient(ops ClientOps) *Client {
106 return &Client{
107 ops: ops,
108 }
109 }
110
111
112
113 func (c *Client) init() error {
114 c.initOnce.Do(c.initWork)
115 return c.initErr
116 }
117
118
119 func (c *Client) initWork() {
120 defer func() {
121 if c.initErr != nil {
122 c.initErr = fmt.Errorf("initializing sumdb.Client: %v", c.initErr)
123 }
124 }()
125
126 c.tileReader.c = c
127 if c.tileHeight == 0 {
128 c.tileHeight = 8
129 }
130 c.tileSaved = make(map[tlog.Tile]bool)
131
132 vkey, err := c.ops.ReadConfig("key")
133 if err != nil {
134 c.initErr = err
135 return
136 }
137 verifier, err := note.NewVerifier(strings.TrimSpace(string(vkey)))
138 if err != nil {
139 c.initErr = err
140 return
141 }
142 c.verifiers = note.VerifierList(verifier)
143 c.name = verifier.Name()
144
145 if c.latest.N == 0 {
146 c.latest.Hash, err = tlog.TreeHash(0, nil)
147 if err != nil {
148 c.initErr = err
149 return
150 }
151 }
152
153 data, err := c.ops.ReadConfig(c.name + "/latest")
154 if err != nil {
155 c.initErr = err
156 return
157 }
158 if err := c.mergeLatest(data); err != nil {
159 c.initErr = err
160 return
161 }
162 }
163
164
165
166
167
168
169 func (c *Client) SetTileHeight(height int) {
170 if atomic.LoadUint32(&c.didLookup) != 0 {
171 panic("SetTileHeight used after Lookup")
172 }
173 if height <= 0 {
174 panic("invalid call to SetTileHeight")
175 }
176 if c.tileHeight != 0 {
177 panic("multiple calls to SetTileHeight")
178 }
179 c.tileHeight = height
180 }
181
182
183
184
185
186
187 func (c *Client) SetGONOSUMDB(list string) {
188 if atomic.LoadUint32(&c.didLookup) != 0 {
189 panic("SetGONOSUMDB used after Lookup")
190 }
191 if c.nosumdb != "" {
192 panic("multiple calls to SetGONOSUMDB")
193 }
194 c.nosumdb = list
195 }
196
197
198
199
// ErrGONOSUMDB is returned by Lookup, unwrapped, for a module path that
// matches the pattern list stored by SetGONOSUMDB.
var ErrGONOSUMDB = errors.New("skipped (listed in GONOSUMDB)")
201
202 func (c *Client) skip(target string) bool {
203 return module.MatchPrefixPatterns(c.nosumdb, target)
204 }
205
206
207
208
209 func (c *Client) Lookup(path, vers string) (lines []string, err error) {
210 atomic.StoreUint32(&c.didLookup, 1)
211
212 if c.skip(path) {
213 return nil, ErrGONOSUMDB
214 }
215
216 defer func() {
217 if err != nil {
218 err = fmt.Errorf("%s@%s: %v", path, vers, err)
219 }
220 }()
221
222 if err := c.init(); err != nil {
223 return nil, err
224 }
225
226
227 epath, err := module.EscapePath(path)
228 if err != nil {
229 return nil, err
230 }
231 evers, err := module.EscapeVersion(strings.TrimSuffix(vers, "/go.mod"))
232 if err != nil {
233 return nil, err
234 }
235 remotePath := "/lookup/" + epath + "@" + evers
236 file := c.name + remotePath
237
238
239
240
241
242
243 type cached struct {
244 data []byte
245 err error
246 }
247 result := c.record.Do(file, func() interface{} {
248
249 writeCache := false
250 data, err := c.ops.ReadCache(file)
251 if err != nil {
252 data, err = c.ops.ReadRemote(remotePath)
253 if err != nil {
254 return cached{nil, err}
255 }
256 writeCache = true
257 }
258
259
260 id, text, treeMsg, err := tlog.ParseRecord(data)
261 if err != nil {
262 return cached{nil, err}
263 }
264 if err := c.mergeLatest(treeMsg); err != nil {
265 return cached{nil, err}
266 }
267 if err := c.checkRecord(id, text); err != nil {
268 return cached{nil, err}
269 }
270
271
272
273 if writeCache {
274 c.ops.WriteCache(file, data)
275 }
276
277 return cached{data, nil}
278 }).(cached)
279 if result.err != nil {
280 return nil, result.err
281 }
282
283
284
285 prefix := path + " " + vers + " "
286 var hashes []string
287 for _, line := range strings.Split(string(result.data), "\n") {
288 if strings.HasPrefix(line, prefix) {
289 hashes = append(hashes, line)
290 }
291 }
292 return hashes, nil
293 }
294
295
296
297
298
299
300
301
302
303
304 func (c *Client) mergeLatest(msg []byte) error {
305
306 when, err := c.mergeLatestMem(msg)
307 if err != nil {
308 return err
309 }
310 if when != msgFuture {
311
312
313 return nil
314 }
315
316
317
318
319
320 for {
321 msg, err := c.ops.ReadConfig(c.name + "/latest")
322 if err != nil {
323 return err
324 }
325 when, err := c.mergeLatestMem(msg)
326 if err != nil {
327 return err
328 }
329 if when != msgPast {
330
331
332 return nil
333 }
334
335
336 c.latestMu.Lock()
337 latestMsg := c.latestMsg
338 c.latestMu.Unlock()
339 if err := c.ops.WriteConfig(c.name+"/latest", msg, latestMsg); err != ErrWriteConflict {
340
341 return err
342 }
343 }
344 }
345
// Results of mergeLatestMem: how the merged tree relates to the
// in-memory latest tree. The zero value is deliberately unused, so it
// can accompany an error.
const (
	msgPast   = iota + 1 // older than the in-memory latest tree
	msgNow               // same size as the in-memory latest tree
	msgFuture            // newer; it has been installed as the latest tree
)
351
352
353
354
355
356
357
358
359
360 func (c *Client) mergeLatestMem(msg []byte) (when int, err error) {
361 if len(msg) == 0 {
362
363 c.latestMu.Lock()
364 latest := c.latest
365 c.latestMu.Unlock()
366 if latest.N == 0 {
367 return msgNow, nil
368 }
369 return msgPast, nil
370 }
371
372 note, err := note.Open(msg, c.verifiers)
373 if err != nil {
374 return 0, fmt.Errorf("reading tree note: %v\nnote:\n%s", err, msg)
375 }
376 tree, err := tlog.ParseTree([]byte(note.Text))
377 if err != nil {
378 return 0, fmt.Errorf("reading tree: %v\ntree:\n%s", err, note.Text)
379 }
380
381
382
383
384 c.latestMu.Lock()
385 latest := c.latest
386 latestMsg := c.latestMsg
387 c.latestMu.Unlock()
388
389 for {
390
391 if tree.N <= latest.N {
392 if err := c.checkTrees(tree, msg, latest, latestMsg); err != nil {
393 return 0, err
394 }
395 if tree.N < latest.N {
396 return msgPast, nil
397 }
398 return msgNow, nil
399 }
400
401
402 if err := c.checkTrees(latest, latestMsg, tree, msg); err != nil {
403 return 0, err
404 }
405
406
407
408 c.latestMu.Lock()
409 installed := false
410 if c.latest == latest {
411 installed = true
412 c.latest = tree
413 c.latestMsg = msg
414 } else {
415 latest = c.latest
416 latestMsg = c.latestMsg
417 }
418 c.latestMu.Unlock()
419
420 if installed {
421 return msgFuture, nil
422 }
423 }
424 }
425
426
427
428
429
430 func (c *Client) checkTrees(older tlog.Tree, olderNote []byte, newer tlog.Tree, newerNote []byte) error {
431 thr := tlog.TileHashReader(newer, &c.tileReader)
432 h, err := tlog.TreeHash(older.N, thr)
433 if err != nil {
434 if older.N == newer.N {
435 return fmt.Errorf("checking tree#%d: %v", older.N, err)
436 }
437 return fmt.Errorf("checking tree#%d against tree#%d: %v", older.N, newer.N, err)
438 }
439 if h == older.Hash {
440 return nil
441 }
442
443
444
445 var buf bytes.Buffer
446 fmt.Fprintf(&buf, "SECURITY ERROR\n")
447 fmt.Fprintf(&buf, "go.sum database server misbehavior detected!\n\n")
448 indent := func(b []byte) []byte {
449 return bytes.Replace(b, []byte("\n"), []byte("\n\t"), -1)
450 }
451 fmt.Fprintf(&buf, "old database:\n\t%s\n", indent(olderNote))
452 fmt.Fprintf(&buf, "new database:\n\t%s\n", indent(newerNote))
453
454
455
456
457
458
459
460
461
462
463
464 fmt.Fprintf(&buf, "proof of misbehavior:\n\t%v", h)
465 if p, err := tlog.ProveTree(newer.N, older.N, thr); err != nil {
466 fmt.Fprintf(&buf, "\tinternal error: %v\n", err)
467 } else if err := tlog.CheckTree(p, newer.N, newer.Hash, older.N, h); err != nil {
468 fmt.Fprintf(&buf, "\tinternal error: generated inconsistent proof\n")
469 } else {
470 for _, h := range p {
471 fmt.Fprintf(&buf, "\n\t%v", h)
472 }
473 }
474 c.ops.SecurityError(buf.String())
475 return ErrSecurity
476 }
477
478
479 func (c *Client) checkRecord(id int64, data []byte) error {
480 c.latestMu.Lock()
481 latest := c.latest
482 c.latestMu.Unlock()
483
484 if id >= latest.N {
485 return fmt.Errorf("cannot validate record %d in tree of size %d", id, latest.N)
486 }
487 hashes, err := tlog.TileHashReader(latest, &c.tileReader).ReadHashes([]int64{tlog.StoredHashIndex(0, id)})
488 if err != nil {
489 return err
490 }
491 if hashes[0] == tlog.RecordHash(data) {
492 return nil
493 }
494 return fmt.Errorf("cannot authenticate record data in server response")
495 }
496
497
498
499
500 type tileReader struct {
501 c *Client
502 }
503
504 func (r *tileReader) Height() int {
505 return r.c.tileHeight
506 }
507
508
509
510 func (r *tileReader) ReadTiles(tiles []tlog.Tile) ([][]byte, error) {
511
512 data := make([][]byte, len(tiles))
513 errs := make([]error, len(tiles))
514 var wg sync.WaitGroup
515 for i, tile := range tiles {
516 wg.Add(1)
517 go func(i int, tile tlog.Tile) {
518 defer wg.Done()
519 defer func() {
520 if e := recover(); e != nil {
521 errs[i] = fmt.Errorf("panic: %v", e)
522 }
523 }()
524 data[i], errs[i] = r.c.readTile(tile)
525 }(i, tile)
526 }
527 wg.Wait()
528
529 for _, err := range errs {
530 if err != nil {
531 return nil, err
532 }
533 }
534
535 return data, nil
536 }
537
538
539 func (c *Client) tileCacheKey(tile tlog.Tile) string {
540 return c.name + "/" + tile.Path()
541 }
542
543
544 func (c *Client) tileRemotePath(tile tlog.Tile) string {
545 return "/" + tile.Path()
546 }
547
548
549 func (c *Client) readTile(tile tlog.Tile) ([]byte, error) {
550 type cached struct {
551 data []byte
552 err error
553 }
554
555 result := c.tileCache.Do(tile, func() interface{} {
556
557 data, err := c.ops.ReadCache(c.tileCacheKey(tile))
558 if err == nil {
559 c.markTileSaved(tile)
560 return cached{data, nil}
561 }
562
563
564
565
566 full := tile
567 full.W = 1 << uint(tile.H)
568 if tile != full {
569 data, err := c.ops.ReadCache(c.tileCacheKey(full))
570 if err == nil {
571 c.markTileSaved(tile)
572 return cached{data[:len(data)/full.W*tile.W], nil}
573 }
574 }
575
576
577 data, err = c.ops.ReadRemote(c.tileRemotePath(tile))
578 if err == nil {
579 return cached{data, nil}
580 }
581
582
583
584
585
586 if tile != full {
587 data, err := c.ops.ReadRemote(c.tileRemotePath(full))
588 if err == nil {
589
590
591
592
593
594 return cached{data[:len(data)/full.W*tile.W], nil}
595 }
596 }
597
598
599
600 return cached{nil, err}
601 }).(cached)
602
603 return result.data, result.err
604 }
605
606
607
608 func (c *Client) markTileSaved(tile tlog.Tile) {
609 c.tileSavedMu.Lock()
610 c.tileSaved[tile] = true
611 c.tileSavedMu.Unlock()
612 }
613
614
615 func (r *tileReader) SaveTiles(tiles []tlog.Tile, data [][]byte) {
616 c := r.c
617
618
619
620 save := make([]bool, len(tiles))
621 c.tileSavedMu.Lock()
622 for i, tile := range tiles {
623 if !c.tileSaved[tile] {
624 save[i] = true
625 c.tileSaved[tile] = true
626 }
627 }
628 c.tileSavedMu.Unlock()
629
630 for i, tile := range tiles {
631 if save[i] {
632
633
634
635
636 c.ops.WriteCache(c.name+"/"+tile.Path(), data[i])
637 }
638 }
639 }
640
View as plain text