1
2
3
4
5
6
7 package chacha20
8
9 import (
10 "crypto/cipher"
11 "encoding/binary"
12 "errors"
13 "math/bits"
14
15 "golang.org/x/crypto/internal/alias"
16 )
17
// Byte lengths of the key and nonces accepted by NewUnauthenticatedCipher.
const (
	// KeySize is the required key length in bytes, for both the 12-byte and
	// the 24-byte nonce variants.
	KeySize = 32

	// NonceSize is the length in bytes of a standard ChaCha20 nonce.
	NonceSize = 12

	// NonceSizeX is the length in bytes of an extended nonce. A nonce of this
	// length selects the XChaCha20 construction, where the first 16 bytes are
	// mixed into the key with HChaCha20.
	NonceSizeX = 24
)
33
34
35
// Cipher is a stateful ChaCha20 (or, with a 24-byte nonce, XChaCha20) key
// stream bound to one key and nonce. *Cipher satisfies cipher.Stream.
type Cipher struct {
	// Variable part of the 16-word ChaCha20 state: 8 key words, the 32-bit
	// block counter (incremented once per 64-byte block), and 3 nonce words.
	// The 4 constant words are j0..j3 and are not stored.
	key     [8]uint32
	counter uint32
	nonce   [3]uint32

	// The last len bytes of buf hold key stream produced by a previous
	// XORKeyStream call but not yet consumed. bufSize is defined elsewhere;
	// it is used as a whole number of blocks (bufSize / blockSize).
	buf [bufSize]byte
	len int

	// overflow is set once the block counter has reached its final value:
	// no more blocks may be generated, so a later XORKeyStream that needs new
	// key stream panics, and SetCounter always panics.
	overflow bool

	// Cached results of the three first-round column quarter rounds that do
	// not depend on the counter; filled in once by xorKeyStreamBlocksGeneric
	// when precompDone is false.
	precompDone      bool
	p1, p5, p9, p13  uint32
	p2, p6, p10, p14 uint32
	p3, p7, p11, p15 uint32
}

// Compile-time check that *Cipher implements cipher.Stream.
var _ cipher.Stream = (*Cipher)(nil)
62
63
64
65
66
67
68
69
70
71
72 func NewUnauthenticatedCipher(key, nonce []byte) (*Cipher, error) {
73
74
75
76 c := &Cipher{}
77 return newUnauthenticatedCipher(c, key, nonce)
78 }
79
80 func newUnauthenticatedCipher(c *Cipher, key, nonce []byte) (*Cipher, error) {
81 if len(key) != KeySize {
82 return nil, errors.New("chacha20: wrong key size")
83 }
84 if len(nonce) == NonceSizeX {
85
86
87
88 key, _ = HChaCha20(key, nonce[0:16])
89 cNonce := make([]byte, NonceSize)
90 copy(cNonce[4:12], nonce[16:24])
91 nonce = cNonce
92 } else if len(nonce) != NonceSize {
93 return nil, errors.New("chacha20: wrong nonce size")
94 }
95
96 key, nonce = key[:KeySize], nonce[:NonceSize]
97 c.key = [8]uint32{
98 binary.LittleEndian.Uint32(key[0:4]),
99 binary.LittleEndian.Uint32(key[4:8]),
100 binary.LittleEndian.Uint32(key[8:12]),
101 binary.LittleEndian.Uint32(key[12:16]),
102 binary.LittleEndian.Uint32(key[16:20]),
103 binary.LittleEndian.Uint32(key[20:24]),
104 binary.LittleEndian.Uint32(key[24:28]),
105 binary.LittleEndian.Uint32(key[28:32]),
106 }
107 c.nonce = [3]uint32{
108 binary.LittleEndian.Uint32(nonce[0:4]),
109 binary.LittleEndian.Uint32(nonce[4:8]),
110 binary.LittleEndian.Uint32(nonce[8:12]),
111 }
112 return c, nil
113 }
114
115
// j0..j3 are the fixed first row of the ChaCha20 state: the ASCII string
// "expand 32-byte k" read as four little-endian 32-bit words.
// blockSize is the number of key stream bytes produced per counter value.
const (
	j0 uint32 = 0x61707865 // "expa"
	j1 uint32 = 0x3320646e // "nd 3"
	j2 uint32 = 0x79622d32 // "2-by"
	j3 uint32 = 0x6b206574 // "te k"

	blockSize = 64
)
124
125
126
127
// quarterRound applies the ChaCha20 quarter round to four state words and
// returns the updated words: add, XOR, then rotate left by 16, 12, 8 and 7.
// It is kept straight-line so it stays cheap to inline.
func quarterRound(a, b, c, d uint32) (uint32, uint32, uint32, uint32) {
	a += b
	d = bits.RotateLeft32(d^a, 16)
	c += d
	b = bits.RotateLeft32(b^c, 12)
	a += b
	d = bits.RotateLeft32(d^a, 8)
	c += d
	b = bits.RotateLeft32(b^c, 7)
	return a, b, c, d
}
143
144
145
146
147
148
149
150
151
152 func (s *Cipher) SetCounter(counter uint32) {
153
154
155
156
157 outputCounter := s.counter - uint32(s.len)/blockSize
158 if s.overflow || counter < outputCounter {
159 panic("chacha20: SetCounter attempted to rollback counter")
160 }
161
162
163
164
165
166 if counter < s.counter {
167 s.len = int(s.counter-counter) * blockSize
168 } else {
169 s.counter = counter
170 s.len = 0
171 }
172 }
173
174
175
176
177
178
179
180
181
182
183
// XORKeyStream XORs each byte of src with the next byte of the cipher's key
// stream and writes the result to dst[:len(src)]; bytes of dst beyond that
// are not touched. It panics if len(dst) < len(src), if dst and src overlap
// only partially, or if the input would need more blocks than the 32-bit
// counter allows.
//
// Key stream generated but not consumed is kept in s.buf, so successive calls
// behave like one call on the concatenated inputs.
func (s *Cipher) XORKeyStream(dst, src []byte) {
	if len(src) == 0 {
		return
	}
	if len(dst) < len(src) {
		panic("chacha20: output smaller than input")
	}
	dst = dst[:len(src)]
	// Identical or disjoint buffers are fine; partial overlap is rejected.
	if alias.InexactOverlap(dst, src) {
		panic("chacha20: invalid buffer overlap")
	}

	// First, use up key stream left in s.buf by a previous call. The unread
	// bytes are the last s.len bytes of the buffer.
	if s.len != 0 {
		keyStream := s.buf[bufSize-s.len:]
		if len(src) < len(keyStream) {
			keyStream = keyStream[:len(src)]
		}
		_ = src[len(keyStream)-1] // bounds check elimination hint
		for i, b := range keyStream {
			dst[i] = src[i] ^ b
		}
		s.len -= len(keyStream)
		dst, src = dst[len(keyStream):], src[len(keyStream):]
	}
	if len(src) == 0 {
		return
	}

	// The block counter must not wrap. Panic if the remaining input needs
	// blocks beyond the last counter value. If it ends exactly on the last
	// block, record that so no more key stream is generated once the buffer
	// is drained.
	numBlocks := (uint64(len(src)) + blockSize - 1) / blockSize
	if s.overflow || uint64(s.counter)+numBlocks > 1<<32 {
		panic("chacha20: counter overflow")
	} else if uint64(s.counter)+numBlocks == 1<<32 {
		s.overflow = true
	}

	// Hand the longest prefix that is a whole number of buffers directly to
	// xorKeyStreamBlocks (defined elsewhere; it is only ever given lengths
	// that are multiples of bufSize).
	full := len(src) - len(src)%bufSize
	if full > 0 {
		s.xorKeyStreamBlocks(dst[:full], src[:full])
	}
	dst, src = dst[full:], src[full:]

	// If producing a whole buffer of blocks would run the counter past 2^32,
	// generate only the blocks the tail needs, one at a time, with the
	// generic implementation. They are placed at the end of s.buf so that
	// the "last s.len bytes" convention used above still holds.
	const blocksPerBuf = bufSize / blockSize
	if uint64(s.counter)+blocksPerBuf > 1<<32 {
		s.buf = [bufSize]byte{}
		numBlocks := (len(src) + blockSize - 1) / blockSize
		buf := s.buf[bufSize-numBlocks*blockSize:]
		copy(buf, src)
		s.xorKeyStreamBlocksGeneric(buf, buf)
		s.len = len(buf) - copy(dst, buf)
		return
	}

	// Otherwise process the tail through a full, zeroed buffer. The bytes
	// past the input start as zero, so afterwards they hold raw key stream,
	// which the next call drains.
	if len(src) > 0 {
		s.buf = [bufSize]byte{}
		copy(s.buf[:], src)
		s.xorKeyStreamBlocks(s.buf[:], s.buf[:])
		s.len = bufSize - copy(dst, s.buf[:])
	}
}
255
// xorKeyStreamBlocksGeneric is the portable ChaCha20 block function. It
// processes one 64-byte block at a time, writing src XOR key stream into dst
// and incrementing s.counter once per block. dst and src must have equal
// length, and that length must be a multiple of blockSize; otherwise it
// panics.
//
// s.counter is a uint32 and would silently wrap here. Callers (XORKeyStream)
// must make sure that cannot happen.
func (s *Cipher) xorKeyStreamBlocksGeneric(dst, src []byte) {
	if len(dst) != len(src) || len(dst)%blockSize != 0 {
		panic("chacha20: internal error: wrong dst and/or src length")
	}

	// Each key stream block comes from this 16-word input state, put through
	// 20 rounds that alternate column quarter rounds (e.g. words 1, 5, 9, 13)
	// and diagonal quarter rounds (e.g. words 1, 6, 11, 12):
	//
	//      0:cccccccc   1:cccccccc   2:cccccccc   3:cccccccc
	//      4:kkkkkkkk   5:kkkkkkkk   6:kkkkkkkk   7:kkkkkkkk
	//      8:kkkkkkkk   9:kkkkkkkk  10:kkkkkkkk  11:kkkkkkkk
	//     12:bbbbbbbb  13:nnnnnnnn  14:nnnnnnnn  15:nnnnnnnn
	//
	//            c=constant k=key b=block counter n=nonce
	//
	// Word 12 is discarded with _ because it changes for every block; the
	// loop reads s.counter directly instead.
	var (
		c0, c1, c2, c3   = j0, j1, j2, j3
		c4, c5, c6, c7   = s.key[0], s.key[1], s.key[2], s.key[3]
		c8, c9, c10, c11 = s.key[4], s.key[5], s.key[6], s.key[7]
		_, c13, c14, c15 = s.counter, s.nonce[0], s.nonce[1], s.nonce[2]
	)

	// Three of the four first-round column quarter rounds, on words
	// (1,5,9,13), (2,6,10,14) and (3,7,11,15), do not involve the counter.
	// They are computed once, cached in s, and reused for every block of
	// this and later calls. NOTE(review): this relies on s.key and s.nonce
	// never changing after construction; nothing visible here modifies them.
	if !s.precompDone {
		s.p1, s.p5, s.p9, s.p13 = quarterRound(c1, c5, c9, c13)
		s.p2, s.p6, s.p10, s.p14 = quarterRound(c2, c6, c10, c14)
		s.p3, s.p7, s.p11, s.p15 = quarterRound(c3, c7, c11, c15)
		s.precompDone = true
	}

	// len(dst) >= 64 is implied by the length check above. It is presumably
	// kept as a bounds check elimination hint for the dst slicing below.
	for len(src) >= 64 && len(dst) >= 64 {
		// Remaining quarter round of the first (column) round, the only one
		// that uses the counter.
		fcr0, fcr4, fcr8, fcr12 := quarterRound(c0, c4, c8, s.counter)

		// Round 2, the first diagonal round: combines the fresh column-0
		// result with the cached columns 1-3.
		x0, x5, x10, x15 := quarterRound(fcr0, s.p5, s.p10, s.p15)
		x1, x6, x11, x12 := quarterRound(s.p1, s.p6, s.p11, fcr12)
		x2, x7, x8, x13 := quarterRound(s.p2, s.p7, fcr8, s.p13)
		x3, x4, x9, x14 := quarterRound(s.p3, fcr4, s.p9, s.p14)

		// Rounds 3-20: nine more column + diagonal double rounds.
		for i := 0; i < 9; i++ {
			// Column round.
			x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
			x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
			x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
			x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)

			// Diagonal round.
			x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
			x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
			x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
			x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
		}

		// Combine each final word with its initial word and the matching 4
		// src bytes into dst. addXor is defined elsewhere; presumably it adds
		// the two words (the ChaCha20 feed-forward) and XORs the little-endian
		// result into src (TODO confirm). Word 12 uses this block's counter.
		addXor(dst[0:4], src[0:4], x0, c0)
		addXor(dst[4:8], src[4:8], x1, c1)
		addXor(dst[8:12], src[8:12], x2, c2)
		addXor(dst[12:16], src[12:16], x3, c3)
		addXor(dst[16:20], src[16:20], x4, c4)
		addXor(dst[20:24], src[20:24], x5, c5)
		addXor(dst[24:28], src[24:28], x6, c6)
		addXor(dst[28:32], src[28:32], x7, c7)
		addXor(dst[32:36], src[32:36], x8, c8)
		addXor(dst[36:40], src[36:40], x9, c9)
		addXor(dst[40:44], src[40:44], x10, c10)
		addXor(dst[44:48], src[44:48], x11, c11)
		addXor(dst[48:52], src[48:52], x12, s.counter)
		addXor(dst[52:56], src[52:56], x13, c13)
		addXor(dst[56:60], src[56:60], x14, c14)
		addXor(dst[60:64], src[60:64], x15, c15)

		s.counter += 1

		src, dst = src[blockSize:], dst[blockSize:]
	}
}
340
341
342
343
344 func HChaCha20(key, nonce []byte) ([]byte, error) {
345
346
347
348 out := make([]byte, 32)
349 return hChaCha20(out, key, nonce)
350 }
351
352 func hChaCha20(out, key, nonce []byte) ([]byte, error) {
353 if len(key) != KeySize {
354 return nil, errors.New("chacha20: wrong HChaCha20 key size")
355 }
356 if len(nonce) != 16 {
357 return nil, errors.New("chacha20: wrong HChaCha20 nonce size")
358 }
359
360 x0, x1, x2, x3 := j0, j1, j2, j3
361 x4 := binary.LittleEndian.Uint32(key[0:4])
362 x5 := binary.LittleEndian.Uint32(key[4:8])
363 x6 := binary.LittleEndian.Uint32(key[8:12])
364 x7 := binary.LittleEndian.Uint32(key[12:16])
365 x8 := binary.LittleEndian.Uint32(key[16:20])
366 x9 := binary.LittleEndian.Uint32(key[20:24])
367 x10 := binary.LittleEndian.Uint32(key[24:28])
368 x11 := binary.LittleEndian.Uint32(key[28:32])
369 x12 := binary.LittleEndian.Uint32(nonce[0:4])
370 x13 := binary.LittleEndian.Uint32(nonce[4:8])
371 x14 := binary.LittleEndian.Uint32(nonce[8:12])
372 x15 := binary.LittleEndian.Uint32(nonce[12:16])
373
374 for i := 0; i < 10; i++ {
375
376 x0, x4, x8, x12 = quarterRound(x0, x4, x8, x12)
377 x1, x5, x9, x13 = quarterRound(x1, x5, x9, x13)
378 x2, x6, x10, x14 = quarterRound(x2, x6, x10, x14)
379 x3, x7, x11, x15 = quarterRound(x3, x7, x11, x15)
380
381
382 x0, x5, x10, x15 = quarterRound(x0, x5, x10, x15)
383 x1, x6, x11, x12 = quarterRound(x1, x6, x11, x12)
384 x2, x7, x8, x13 = quarterRound(x2, x7, x8, x13)
385 x3, x4, x9, x14 = quarterRound(x3, x4, x9, x14)
386 }
387
388 _ = out[31]
389 binary.LittleEndian.PutUint32(out[0:4], x0)
390 binary.LittleEndian.PutUint32(out[4:8], x1)
391 binary.LittleEndian.PutUint32(out[8:12], x2)
392 binary.LittleEndian.PutUint32(out[12:16], x3)
393 binary.LittleEndian.PutUint32(out[16:20], x12)
394 binary.LittleEndian.PutUint32(out[20:24], x13)
395 binary.LittleEndian.PutUint32(out[24:28], x14)
396 binary.LittleEndian.PutUint32(out[28:32], x15)
397 return out, nil
398 }
399
View as plain text