1
2
3
4
5 package cryptotest
6
7 import (
8 "bytes"
9 "crypto/cipher"
10 "testing"
11 )
12
// MakeBlock builds the cipher.Block under test from the given key. A non-nil
// error means the key was not accepted; TestBlock treats that as fatal.
type MakeBlock func(key []byte) (cipher.Block, error)
14
15
16
17 func TestBlock(t *testing.T, keySize int, mb MakeBlock) {
18
19 key := make([]byte, keySize)
20 newRandReader(t).Read(key)
21 t.Logf("Cipher key: 0x%x", key)
22
23 block, err := mb(key)
24 if err != nil {
25 t.Fatal(err)
26 }
27
28 blockSize := block.BlockSize()
29
30 t.Run("Encryption", func(t *testing.T) {
31 testCipher(t, block.Encrypt, blockSize)
32 })
33
34 t.Run("Decryption", func(t *testing.T) {
35 testCipher(t, block.Decrypt, blockSize)
36 })
37
38
39
40
41 t.Run("Roundtrip", func(t *testing.T) {
42 rng := newRandReader(t)
43
44
45 before, ciphertext, after := make([]byte, blockSize), make([]byte, blockSize), make([]byte, blockSize)
46
47 rng.Read(before)
48
49 block.Encrypt(ciphertext, before)
50 block.Decrypt(after, ciphertext)
51
52 if !bytes.Equal(after, before) {
53 t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %x, want %x", after, before)
54 }
55
56
57 before, plaintext, after := make([]byte, blockSize), make([]byte, blockSize), make([]byte, blockSize)
58
59 rng.Read(before)
60
61 block.Decrypt(plaintext, before)
62 block.Encrypt(after, plaintext)
63
64 if !bytes.Equal(after, before) {
65 t.Errorf("ciphertext is different after a decrypt/encrypt cycle; got %x, want %x", after, before)
66 }
67 })
68
69 }
70
71 func testCipher(t *testing.T, cipher func(dst, src []byte), blockSize int) {
72 t.Run("AlterInput", func(t *testing.T) {
73 rng := newRandReader(t)
74
75
76
77 src, before := make([]byte, blockSize*2), make([]byte, blockSize*2)
78 rng.Read(src)
79 copy(before, src)
80
81 dst := make([]byte, blockSize)
82
83 cipher(dst, src)
84 if !bytes.Equal(src, before) {
85 t.Errorf("block cipher modified src; got %x, want %x", src, before)
86 }
87 })
88
89 t.Run("Aliasing", func(t *testing.T) {
90 rng := newRandReader(t)
91
92 buff, expectedOutput := make([]byte, blockSize), make([]byte, blockSize)
93
94
95 rng.Read(buff)
96 cipher(expectedOutput, buff)
97
98
99
100 cipher(buff, buff)
101 if !bytes.Equal(buff, expectedOutput) {
102 t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff, expectedOutput)
103 }
104 })
105
106 t.Run("OutOfBoundsWrite", func(t *testing.T) {
107 rng := newRandReader(t)
108
109 src := make([]byte, blockSize)
110 rng.Read(src)
111
112
113 buff := make([]byte, blockSize*3)
114 endOfPrefix, startOfSuffix := blockSize, blockSize*2
115 rng.Read(buff[:endOfPrefix])
116 rng.Read(buff[startOfSuffix:])
117 dst := buff[endOfPrefix:startOfSuffix]
118
119
120 initPrefix, initSuffix := make([]byte, blockSize), make([]byte, blockSize)
121 copy(initPrefix, buff[:endOfPrefix])
122 copy(initSuffix, buff[startOfSuffix:])
123
124
125
126 cipher(dst, src)
127 if !bytes.Equal(buff[startOfSuffix:], initSuffix) {
128 t.Errorf("block cipher did out of bounds write after end of dst slice; got %x, want %x", buff[startOfSuffix:], initSuffix)
129 }
130 if !bytes.Equal(buff[:endOfPrefix], initPrefix) {
131 t.Errorf("block cipher did out of bounds write before beginning of dst slice; got %x, want %x", buff[:endOfPrefix], initPrefix)
132 }
133
134
135
136 dst = buff[endOfPrefix:]
137 cipher(dst, src)
138 if !bytes.Equal(buff[startOfSuffix:], initSuffix) {
139 t.Errorf("block cipher modified dst past BlockSize bytes; got %x, want %x", buff[startOfSuffix:], initSuffix)
140 }
141 })
142
143
144
145
146 t.Run("OutOfBoundsRead", func(t *testing.T) {
147 rng := newRandReader(t)
148
149 src := make([]byte, blockSize)
150 rng.Read(src)
151 expectedDst := make([]byte, blockSize)
152 cipher(expectedDst, src)
153
154
155 buff := make([]byte, blockSize*3)
156 endOfPrefix, startOfSuffix := blockSize, blockSize*2
157
158 copy(buff[endOfPrefix:startOfSuffix], src)
159 rng.Read(buff[:endOfPrefix])
160 rng.Read(buff[startOfSuffix:])
161
162 testDst := make([]byte, blockSize)
163 cipher(testDst, buff[endOfPrefix:startOfSuffix])
164 if !bytes.Equal(testDst, expectedDst) {
165 t.Errorf("block cipher affected by data outside of src slice bounds; got %x, want %x", testDst, expectedDst)
166 }
167
168
169
170 cipher(testDst, buff[endOfPrefix:])
171 if !bytes.Equal(testDst, expectedDst) {
172 t.Errorf("block cipher affected by src data beyond BlockSize bytes; got %x, want %x", buff[startOfSuffix:], expectedDst)
173 }
174 })
175
176 t.Run("NonZeroDst", func(t *testing.T) {
177 rng := newRandReader(t)
178
179
180 src := make([]byte, blockSize)
181 rng.Read(src)
182 expectedDst := make([]byte, blockSize)
183
184 cipher(expectedDst, src)
185
186
187 dst := make([]byte, blockSize*2)
188 rng.Read(dst)
189
190
191 expectedDst = append(expectedDst, dst[blockSize:]...)
192
193 cipher(dst, src)
194 if !bytes.Equal(dst, expectedDst) {
195 t.Errorf("block cipher behavior differs when given non-zero dst; got %x, want %x", dst, expectedDst)
196 }
197 })
198
199 t.Run("BufferOverlap", func(t *testing.T) {
200 rng := newRandReader(t)
201
202 buff := make([]byte, blockSize*2)
203 rng.Read((buff))
204
205
206 src := buff[:blockSize]
207 dst := buff[1 : blockSize+1]
208 mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
209
210
211 src = buff[:blockSize]
212 dst = buff[blockSize-1 : 2*blockSize-1]
213 mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
214
215
216 src = buff[blockSize-1 : 2*blockSize-1]
217 dst = buff[:blockSize]
218 mustPanic(t, "invalid buffer overlap", func() { cipher(dst, src) })
219 })
220
221
222
223
224 t.Run("ShortBlock", func(t *testing.T) {
225
226
227
228 byteSlice := func(n int) []byte { return make([]byte, n+1)[0:n] }
229
230
231 mustPanic(t, "input not full block", func() { cipher(byteSlice(blockSize), byteSlice(blockSize-1)) })
232 mustPanic(t, "output not full block", func() { cipher(byteSlice(blockSize-1), byteSlice(blockSize)) })
233
234
235 mustPanic(t, "input not full block", func() { cipher(byteSlice(1), byteSlice(1)) })
236 mustPanic(t, "input not full block", func() { cipher(byteSlice(100), byteSlice(1)) })
237 mustPanic(t, "output not full block", func() { cipher(byteSlice(1), byteSlice(100)) })
238 })
239 }
240
// mustPanic calls f and records a test error if f returns without panicking.
// msg describes the expected panic and appears only in the failure message;
// the recovered panic value itself is not inspected.
func mustPanic(t *testing.T, msg string, f func()) {
	t.Helper()

	defer func() {
		t.Helper()
		// recover must be called directly in this deferred function.
		if recover() != nil {
			return
		}
		t.Errorf("function did not panic for %q", msg)
	}()
	f()
}
255
View as plain text