1
2
3
4
5
6
7
8
9
10
11 package tlog
12
13 import (
14 "crypto/sha256"
15 "encoding/base64"
16 "errors"
17 "fmt"
18 "math/bits"
19 )
20
21
// HashSize is the size of a Hash in bytes.
const HashSize = 32

// A Hash is a hash identifying a log record or tree root.
type Hash [HashSize]byte

// String returns the padded standard base64 encoding of h.
func (h Hash) String() string {
	buf := make([]byte, base64.StdEncoding.EncodedLen(len(h)))
	base64.StdEncoding.Encode(buf, h[:])
	return string(buf)
}

// MarshalJSON encodes h as a JSON string holding its padded base64 form.
func (h Hash) MarshalJSON() ([]byte, error) {
	s := h.String()
	out := make([]byte, 0, len(s)+2)
	out = append(out, '"')
	out = append(out, s...)
	out = append(out, '"')
	return out, nil
}

// UnmarshalJSON decodes a JSON string holding the padded standard base64
// form of a hash, as produced by MarshalJSON. The input must be exactly a
// double-quoted 44-character encoding whose last character is '='.
// Any other input yields a "cannot decode hash" error and leaves *h unchanged.
func (h *Hash) UnmarshalJSON(data []byte) error {
	const quotedLen = 1 + 44 + 1 // '"' + padded base64 of 32 bytes + '"'
	shapeOK := len(data) == quotedLen &&
		data[0] == '"' &&
		data[quotedLen-2] == '=' &&
		data[quotedLen-1] == '"'
	if !shapeOK {
		return errors.New("cannot decode hash")
	}

	// The single '=' pad was verified above, so strip it and decode the
	// remaining 43 characters with the unpadded alphabet. (Presumably this
	// sidesteps an older StdEncoding.Decode quirk with an exactly-sized
	// destination buffer — TODO confirm.) Decoding into a temporary keeps *h
	// untouched unless the entire input is well-formed.
	var decoded Hash
	if n, err := base64.RawStdEncoding.Decode(decoded[:], data[1:quotedLen-2]); err != nil || n != HashSize {
		return errors.New("cannot decode hash")
	}
	*h = decoded
	return nil
}

// ParseHash parses the padded standard base64 form of a hash, as returned by
// Hash.String. It fails unless s decodes to exactly HashSize bytes.
func ParseHash(s string) (Hash, error) {
	raw, err := base64.StdEncoding.DecodeString(s)
	if err == nil && len(raw) == HashSize {
		var h Hash
		copy(h[:], raw)
		return h, nil
	}
	return Hash{}, fmt.Errorf("malformed hash")
}
69
70
71
// maxpow2 returns k, the largest power of 2 strictly less than n, together
// with l = log₂ k (so k == 1<<l). The result is only meaningful for n >= 2;
// for n <= 1 it returns k = 1, l = 0.
//
// l is derived from the bit length of n-1 rather than by looping on
// 1<<(l+1) < n. That comparison is evaluated in int64, and for n > 1<<62 the
// shift overflows to a value <= 0, so such a loop would never terminate.
func maxpow2(n int64) (k int64, l int) {
	if n <= 1 {
		return 1, 0
	}
	// 2^(L-1) <= n-1 < 2^L, where L = bits.Len64(n-1), so 2^(L-1) is the
	// largest power of two strictly below n.
	l = bits.Len64(uint64(n-1)) - 1
	return 1 << uint(l), l
}
79
80 var zeroPrefix = []byte{0x00}
81
82
83 func RecordHash(data []byte) Hash {
84
85
86 h := sha256.New()
87 h.Write(zeroPrefix)
88 h.Write(data)
89 var h1 Hash
90 h.Sum(h1[:0])
91 return h1
92 }
93
94
95 func NodeHash(left, right Hash) Hash {
96
97
98
99
100 var buf [1 + HashSize + HashSize]byte
101 buf[0] = 0x01
102 copy(buf[1:], left[:])
103 copy(buf[1+HashSize:], right[:])
104 return sha256.Sum256(buf[:])
105 }
106
107
108
109
110
111
112
113
114
115
116
// StoredHashIndex maps the tree coordinates (level, n) to a dense linear
// ordering that can be used for hash storage. Here (level, n) means the n'th
// hash at the given level, where level 0 holds record hashes.
//
// Hashes are numbered in the order they are written: each record hash is
// followed immediately by the hashes of the subtrees that record completes,
// lowest level first. Storage that keeps hashes sequentially can use this to
// find where to read or write a given hash.
func StoredHashIndex(level int, n int64) int64 {
	// The n'th hash at level L is written right after the (2n+1)'th hash at
	// level L-1. Following that chain down reaches the last record of the
	// subtree.
	last := n
	for l := 0; l < level; l++ {
		last = 2*last + 1
	}

	// Record m's hash lives at index m + m/2 + m/4 + ...: among records
	// 0..m-1 there are m/2 completed level-1 subtrees, m/4 level-2 subtrees, ...
	var idx int64
	for m := last; m > 0; m >>= 1 {
		idx += m
	}

	// The hashes completed by that record follow it at offsets 1, 2, ...,
	// so the requested level sits `level` slots after the record hash.
	return idx + int64(level)
}
133
134
135
// SplitStoredHashIndex is the inverse of StoredHashIndex: it returns the
// (level, n) coordinates whose storage index is index. It panics if index is
// negative.
func SplitStoredHashIndex(index int64) (level int, n int64) {
	// The level-0 index of record r is r + r/2 + r/4 + ... < 2r, so the record
	// we want is in [index/2, index/2 + log₂(index)]. Start at the low end.
	rec := index / 2
	var leafIdx int64 // storage index of record rec's hash (StoredHashIndex(0, rec))
	for m := rec; m > 0; m >>= 1 {
		leafIdx += m
	}
	if leafIdx > index {
		panic("bad math")
	}

	// Record rec's hash is followed by one hash per subtree it completes,
	// i.e. trailingZeros(rec+1) of them, and then by record rec+1's hash.
	// Advance while that next record hash still lies at or before index.
	for {
		nextLeaf := leafIdx + 1 + int64(bits.TrailingZeros64(uint64(rec+1)))
		if nextLeaf > index {
			break
		}
		rec, leafIdx = rec+1, nextLeaf
	}

	// index was written together with record rec, so it is one of
	// (0, rec), (1, rec/2), (2, rec/4), ...
	level = int(index - leafIdx)
	n = rec >> uint(level)
	return level, n
}
159
160
161
162 func StoredHashCount(n int64) int64 {
163 if n == 0 {
164 return 0
165 }
166
167 numHash := StoredHashIndex(0, n-1) + 1
168
169 for i := uint64(n - 1); i&1 != 0; i >>= 1 {
170 numHash++
171 }
172 return numHash
173 }
174
175
176
177
178
179
180
181
182 func StoredHashes(n int64, data []byte, r HashReader) ([]Hash, error) {
183 return StoredHashesForRecordHash(n, RecordHash(data), r)
184 }
185
186
187
188 func StoredHashesForRecordHash(n int64, h Hash, r HashReader) ([]Hash, error) {
189
190 hashes := []Hash{h}
191
192
193
194
195 m := int(bits.TrailingZeros64(uint64(n + 1)))
196 indexes := make([]int64, m)
197 for i := 0; i < m; i++ {
198
199
200 indexes[m-1-i] = StoredHashIndex(i, n>>uint(i)-1)
201 }
202
203
204 old, err := r.ReadHashes(indexes)
205 if err != nil {
206 return nil, err
207 }
208 if len(old) != len(indexes) {
209 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(old))
210 }
211
212
213 for i := 0; i < m; i++ {
214 h = NodeHash(old[m-1-i], h)
215 hashes = append(hashes, h)
216 }
217 return hashes, nil
218 }
219
220
221 type HashReader interface {
222
223
224
225
226
227 ReadHashes(indexes []int64) ([]Hash, error)
228 }
229
230
231 type HashReaderFunc func([]int64) ([]Hash, error)
232
233 func (f HashReaderFunc) ReadHashes(indexes []int64) ([]Hash, error) {
234 return f(indexes)
235 }
236
237
238
239 var emptyHash = Hash{
240 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14,
241 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
242 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c,
243 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55,
244 }
245
246
247
248
249
250 func TreeHash(n int64, r HashReader) (Hash, error) {
251 if n == 0 {
252 return emptyHash, nil
253 }
254 indexes := subTreeIndex(0, n, nil)
255 hashes, err := r.ReadHashes(indexes)
256 if err != nil {
257 return Hash{}, err
258 }
259 if len(hashes) != len(indexes) {
260 return Hash{}, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
261 }
262 hash, hashes := subTreeHash(0, n, hashes)
263 if len(hashes) != 0 {
264 panic("tlog: bad index math in TreeHash")
265 }
266 return hash, nil
267 }
268
269
270
271
272
273 func subTreeIndex(lo, hi int64, need []int64) []int64 {
274
275 for lo < hi {
276 k, level := maxpow2(hi - lo + 1)
277 if lo&(k-1) != 0 {
278 panic("tlog: bad math in subTreeIndex")
279 }
280 need = append(need, StoredHashIndex(level, lo>>uint(level)))
281 lo += k
282 }
283 return need
284 }
285
286
287
288
289
290 func subTreeHash(lo, hi int64, hashes []Hash) (Hash, []Hash) {
291
292
293
294
295 numTree := 0
296 for lo < hi {
297 k, _ := maxpow2(hi - lo + 1)
298 if lo&(k-1) != 0 || lo >= hi {
299 panic("tlog: bad math in subTreeHash")
300 }
301 numTree++
302 lo += k
303 }
304
305 if len(hashes) < numTree {
306 panic("tlog: bad index math in subTreeHash")
307 }
308
309
310 h := hashes[numTree-1]
311 for i := numTree - 2; i >= 0; i-- {
312 h = NodeHash(hashes[i], h)
313 }
314 return h, hashes[numTree:]
315 }
316
317
318
319 type RecordProof []Hash
320
321
322 func ProveRecord(t, n int64, r HashReader) (RecordProof, error) {
323 if t < 0 || n < 0 || n >= t {
324 return nil, fmt.Errorf("tlog: invalid inputs in ProveRecord")
325 }
326 indexes := leafProofIndex(0, t, n, nil)
327 if len(indexes) == 0 {
328 return RecordProof{}, nil
329 }
330 hashes, err := r.ReadHashes(indexes)
331 if err != nil {
332 return nil, err
333 }
334 if len(hashes) != len(indexes) {
335 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
336 }
337
338 p, hashes := leafProof(0, t, n, hashes)
339 if len(hashes) != 0 {
340 panic("tlog: bad index math in ProveRecord")
341 }
342 return p, nil
343 }
344
345
346
347
348
349 func leafProofIndex(lo, hi, n int64, need []int64) []int64 {
350
351 if !(lo <= n && n < hi) {
352 panic("tlog: bad math in leafProofIndex")
353 }
354 if lo+1 == hi {
355 return need
356 }
357 if k, _ := maxpow2(hi - lo); n < lo+k {
358 need = leafProofIndex(lo, lo+k, n, need)
359 need = subTreeIndex(lo+k, hi, need)
360 } else {
361 need = subTreeIndex(lo, lo+k, need)
362 need = leafProofIndex(lo+k, hi, n, need)
363 }
364 return need
365 }
366
367
368
369
370 func leafProof(lo, hi, n int64, hashes []Hash) (RecordProof, []Hash) {
371
372 if !(lo <= n && n < hi) {
373 panic("tlog: bad math in leafProof")
374 }
375
376 if lo+1 == hi {
377
378
379 return RecordProof{}, hashes
380 }
381
382
383
384 var p RecordProof
385 var th Hash
386 if k, _ := maxpow2(hi - lo); n < lo+k {
387
388 p, hashes = leafProof(lo, lo+k, n, hashes)
389 th, hashes = subTreeHash(lo+k, hi, hashes)
390 } else {
391
392 th, hashes = subTreeHash(lo, lo+k, hashes)
393 p, hashes = leafProof(lo+k, hi, n, hashes)
394 }
395 return append(p, th), hashes
396 }
397
398 var errProofFailed = errors.New("invalid transparency proof")
399
400
401
402 func CheckRecord(p RecordProof, t int64, th Hash, n int64, h Hash) error {
403 if t < 0 || n < 0 || n >= t {
404 return fmt.Errorf("tlog: invalid inputs in CheckRecord")
405 }
406 th2, err := runRecordProof(p, 0, t, n, h)
407 if err != nil {
408 return err
409 }
410 if th2 == th {
411 return nil
412 }
413 return errProofFailed
414 }
415
416
417
418
419 func runRecordProof(p RecordProof, lo, hi, n int64, leafHash Hash) (Hash, error) {
420
421 if !(lo <= n && n < hi) {
422 panic("tlog: bad math in runRecordProof")
423 }
424
425 if lo+1 == hi {
426
427
428 if len(p) != 0 {
429 return Hash{}, errProofFailed
430 }
431 return leafHash, nil
432 }
433
434 if len(p) == 0 {
435 return Hash{}, errProofFailed
436 }
437
438 k, _ := maxpow2(hi - lo)
439 if n < lo+k {
440 th, err := runRecordProof(p[:len(p)-1], lo, lo+k, n, leafHash)
441 if err != nil {
442 return Hash{}, err
443 }
444 return NodeHash(th, p[len(p)-1]), nil
445 } else {
446 th, err := runRecordProof(p[:len(p)-1], lo+k, hi, n, leafHash)
447 if err != nil {
448 return Hash{}, err
449 }
450 return NodeHash(p[len(p)-1], th), nil
451 }
452 }
453
454
455
456
457 type TreeProof []Hash
458
459
460
461 func ProveTree(t, n int64, h HashReader) (TreeProof, error) {
462 if t < 1 || n < 1 || n > t {
463 return nil, fmt.Errorf("tlog: invalid inputs in ProveTree")
464 }
465 indexes := treeProofIndex(0, t, n, nil)
466 if len(indexes) == 0 {
467 return TreeProof{}, nil
468 }
469 hashes, err := h.ReadHashes(indexes)
470 if err != nil {
471 return nil, err
472 }
473 if len(hashes) != len(indexes) {
474 return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
475 }
476
477 p, hashes := treeProof(0, t, n, hashes)
478 if len(hashes) != 0 {
479 panic("tlog: bad index math in ProveTree")
480 }
481 return p, nil
482 }
483
484
485
486
487 func treeProofIndex(lo, hi, n int64, need []int64) []int64 {
488
489 if !(lo < n && n <= hi) {
490 panic("tlog: bad math in treeProofIndex")
491 }
492
493 if n == hi {
494 if lo == 0 {
495 return need
496 }
497 return subTreeIndex(lo, hi, need)
498 }
499
500 if k, _ := maxpow2(hi - lo); n <= lo+k {
501 need = treeProofIndex(lo, lo+k, n, need)
502 need = subTreeIndex(lo+k, hi, need)
503 } else {
504 need = subTreeIndex(lo, lo+k, need)
505 need = treeProofIndex(lo+k, hi, n, need)
506 }
507 return need
508 }
509
510
511
512
513 func treeProof(lo, hi, n int64, hashes []Hash) (TreeProof, []Hash) {
514
515 if !(lo < n && n <= hi) {
516 panic("tlog: bad math in treeProof")
517 }
518
519
520 if n == hi {
521 if lo == 0 {
522
523
524 return TreeProof{}, hashes
525 }
526 th, hashes := subTreeHash(lo, hi, hashes)
527 return TreeProof{th}, hashes
528 }
529
530
531
532 var p TreeProof
533 var th Hash
534 if k, _ := maxpow2(hi - lo); n <= lo+k {
535
536 p, hashes = treeProof(lo, lo+k, n, hashes)
537 th, hashes = subTreeHash(lo+k, hi, hashes)
538 } else {
539
540 th, hashes = subTreeHash(lo, lo+k, hashes)
541 p, hashes = treeProof(lo+k, hi, n, hashes)
542 }
543 return append(p, th), hashes
544 }
545
546
547
548 func CheckTree(p TreeProof, t int64, th Hash, n int64, h Hash) error {
549 if t < 1 || n < 1 || n > t {
550 return fmt.Errorf("tlog: invalid inputs in CheckTree")
551 }
552 h2, th2, err := runTreeProof(p, 0, t, n, h)
553 if err != nil {
554 return err
555 }
556 if th2 == th && h2 == h {
557 return nil
558 }
559 return errProofFailed
560 }
561
562
563
564
565
566 func runTreeProof(p TreeProof, lo, hi, n int64, old Hash) (Hash, Hash, error) {
567
568 if !(lo < n && n <= hi) {
569 panic("tlog: bad math in runTreeProof")
570 }
571
572
573 if n == hi {
574 if lo == 0 {
575 if len(p) != 0 {
576 return Hash{}, Hash{}, errProofFailed
577 }
578 return old, old, nil
579 }
580 if len(p) != 1 {
581 return Hash{}, Hash{}, errProofFailed
582 }
583 return p[0], p[0], nil
584 }
585
586 if len(p) == 0 {
587 return Hash{}, Hash{}, errProofFailed
588 }
589
590
591 k, _ := maxpow2(hi - lo)
592 if n <= lo+k {
593 oh, th, err := runTreeProof(p[:len(p)-1], lo, lo+k, n, old)
594 if err != nil {
595 return Hash{}, Hash{}, err
596 }
597 return oh, NodeHash(th, p[len(p)-1]), nil
598 } else {
599 oh, th, err := runTreeProof(p[:len(p)-1], lo+k, hi, n, old)
600 if err != nil {
601 return Hash{}, Hash{}, err
602 }
603 return NodeHash(p[len(p)-1], oh), NodeHash(p[len(p)-1], th), nil
604 }
605 }
606
View as plain text