1
2
3
4
5
6 package base64
7
8 import (
9 "encoding/binary"
10 "io"
11 "slices"
12 "strconv"
13 )
14
15
18
19
20
21
22
23
// An Encoding describes one base64 variant: a 64-symbol alphabet, an
// optional padding character and a strictness flag. Values are built with
// NewEncoding and customised with WithPadding and Strict.
type Encoding struct {
	// encode maps a 6-bit value to its alphabet symbol.
	encode [64]byte
	// decodeMap is the inverse of encode, indexed by input byte; entries
	// for bytes outside the alphabet hold 0xff.
	decodeMap [256]byte
	// padChar is the padding character, or NoPadding when output is
	// unpadded.
	padChar rune
	// strict makes decoding reject input whose unused trailing bits are
	// not zero.
	strict bool
}
30
// Padding settings accepted by Encoding.WithPadding.
const (
	// StdPadding is the '=' padding character that NewEncoding installs
	// by default.
	StdPadding rune = '='
	// NoPadding is the sentinel value that turns padding off.
	NoPadding rune = -1
)
35
const (
	// invalidIndex is the decode-table value marking a byte that is not
	// part of the alphabet.
	invalidIndex = '\xff'

	// decodeMapInitialize is 256 bytes of 0xff (16 rows of 16). It is
	// copied into a fresh decode table so that every entry starts out as
	// invalidIndex before the alphabet is filled in.
	decodeMapInitialize = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" +
		"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
)
56
57
58
59
60
61
62
63
64 func NewEncoding(encoder string) *Encoding {
65 if len(encoder) != 64 {
66 panic("encoding alphabet is not 64-bytes long")
67 }
68
69 e := new(Encoding)
70 e.padChar = StdPadding
71 copy(e.encode[:], encoder)
72 copy(e.decodeMap[:], decodeMapInitialize)
73
74 for i := 0; i < len(encoder); i++ {
75
76
77
78 switch {
79 case encoder[i] == '\n' || encoder[i] == '\r':
80 panic("encoding alphabet contains newline character")
81 case e.decodeMap[encoder[i]] != invalidIndex:
82 panic("encoding alphabet includes duplicate symbols")
83 }
84 e.decodeMap[encoder[i]] = uint8(i)
85 }
86 return e
87 }
88
89
90
91
92
93
94
95
96 func (enc Encoding) WithPadding(padding rune) *Encoding {
97 switch {
98 case padding < NoPadding || padding == '\r' || padding == '\n' || padding > 0xff:
99 panic("invalid padding")
100 case padding != NoPadding && enc.decodeMap[byte(padding)] != invalidIndex:
101 panic("padding contained in alphabet")
102 }
103 enc.padChar = padding
104 return &enc
105 }
106
107
108
109
110
111
112
113 func (enc Encoding) Strict() *Encoding {
114 enc.strict = true
115 return &enc
116 }
117
118
119 var StdEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/")
120
121
122
123 var URLEncoding = NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")
124
125
126
127
128 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
129
130
131
132
133 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
134
135
138
139
140
141
142
143
144
145 func (enc *Encoding) Encode(dst, src []byte) {
146 if len(src) == 0 {
147 return
148 }
149
150
151
152 _ = enc.encode
153
154 di, si := 0, 0
155 n := (len(src) / 3) * 3
156 for si < n {
157
158 val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
159
160 dst[di+0] = enc.encode[val>>18&0x3F]
161 dst[di+1] = enc.encode[val>>12&0x3F]
162 dst[di+2] = enc.encode[val>>6&0x3F]
163 dst[di+3] = enc.encode[val&0x3F]
164
165 si += 3
166 di += 4
167 }
168
169 remain := len(src) - si
170 if remain == 0 {
171 return
172 }
173
174 val := uint(src[si+0]) << 16
175 if remain == 2 {
176 val |= uint(src[si+1]) << 8
177 }
178
179 dst[di+0] = enc.encode[val>>18&0x3F]
180 dst[di+1] = enc.encode[val>>12&0x3F]
181
182 switch remain {
183 case 2:
184 dst[di+2] = enc.encode[val>>6&0x3F]
185 if enc.padChar != NoPadding {
186 dst[di+3] = byte(enc.padChar)
187 }
188 case 1:
189 if enc.padChar != NoPadding {
190 dst[di+2] = byte(enc.padChar)
191 dst[di+3] = byte(enc.padChar)
192 }
193 }
194 }
195
196
197
198 func (enc *Encoding) AppendEncode(dst, src []byte) []byte {
199 n := enc.EncodedLen(len(src))
200 dst = slices.Grow(dst, n)
201 enc.Encode(dst[len(dst):][:n], src)
202 return dst[:len(dst)+n]
203 }
204
205
206 func (enc *Encoding) EncodeToString(src []byte) string {
207 buf := make([]byte, enc.EncodedLen(len(src)))
208 enc.Encode(buf, src)
209 return string(buf)
210 }
211
212 type encoder struct {
213 err error
214 enc *Encoding
215 w io.Writer
216 buf [3]byte
217 nbuf int
218 out [1024]byte
219 }
220
221 func (e *encoder) Write(p []byte) (n int, err error) {
222 if e.err != nil {
223 return 0, e.err
224 }
225
226
227 if e.nbuf > 0 {
228 var i int
229 for i = 0; i < len(p) && e.nbuf < 3; i++ {
230 e.buf[e.nbuf] = p[i]
231 e.nbuf++
232 }
233 n += i
234 p = p[i:]
235 if e.nbuf < 3 {
236 return
237 }
238 e.enc.Encode(e.out[:], e.buf[:])
239 if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
240 return n, e.err
241 }
242 e.nbuf = 0
243 }
244
245
246 for len(p) >= 3 {
247 nn := len(e.out) / 4 * 3
248 if nn > len(p) {
249 nn = len(p)
250 nn -= nn % 3
251 }
252 e.enc.Encode(e.out[:], p[:nn])
253 if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
254 return n, e.err
255 }
256 n += nn
257 p = p[nn:]
258 }
259
260
261 copy(e.buf[:], p)
262 e.nbuf = len(p)
263 n += len(p)
264 return
265 }
266
267
268
269 func (e *encoder) Close() error {
270
271 if e.err == nil && e.nbuf > 0 {
272 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
273 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
274 e.nbuf = 0
275 }
276 return e.err
277 }
278
279
280
281
282
283
284 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
285 return &encoder{enc: enc, w: w}
286 }
287
288
289
290 func (enc *Encoding) EncodedLen(n int) int {
291 if enc.padChar == NoPadding {
292 return n/3*4 + (n%3*8+5)/6
293 }
294 return (n + 2) / 3 * 4
295 }
296
297
300
// CorruptInputError reports malformed base64 input; its value is the byte
// offset into the input at which the problem was detected.
type CorruptInputError int64

// Error implements the error interface, reporting the offending offset in
// decimal.
func (e CorruptInputError) Error() string {
	const prefix = "illegal base64 data at input byte "
	offset := strconv.FormatInt(int64(e), 10)
	return prefix + offset
}
306
307
308
309
310
311
312 func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
313
314 var dbuf [4]byte
315 dlen := 4
316
317
318 _ = enc.decodeMap
319
320 for j := 0; j < len(dbuf); j++ {
321 if len(src) == si {
322 switch {
323 case j == 0:
324 return si, 0, nil
325 case j == 1, enc.padChar != NoPadding:
326 return si, 0, CorruptInputError(si - j)
327 }
328 dlen = j
329 break
330 }
331 in := src[si]
332 si++
333
334 out := enc.decodeMap[in]
335 if out != 0xff {
336 dbuf[j] = out
337 continue
338 }
339
340 if in == '\n' || in == '\r' {
341 j--
342 continue
343 }
344
345 if rune(in) != enc.padChar {
346 return si, 0, CorruptInputError(si - 1)
347 }
348
349
350 switch j {
351 case 0, 1:
352
353 return si, 0, CorruptInputError(si - 1)
354 case 2:
355
356
357 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
358 si++
359 }
360 if si == len(src) {
361
362 return si, 0, CorruptInputError(len(src))
363 }
364 if rune(src[si]) != enc.padChar {
365
366 return si, 0, CorruptInputError(si - 1)
367 }
368
369 si++
370 }
371
372
373 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
374 si++
375 }
376 if si < len(src) {
377
378 err = CorruptInputError(si)
379 }
380 dlen = j
381 break
382 }
383
384
385 val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
386 dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
387 switch dlen {
388 case 4:
389 dst[2] = dbuf[2]
390 dbuf[2] = 0
391 fallthrough
392 case 3:
393 dst[1] = dbuf[1]
394 if enc.strict && dbuf[2] != 0 {
395 return si, 0, CorruptInputError(si - 1)
396 }
397 dbuf[1] = 0
398 fallthrough
399 case 2:
400 dst[0] = dbuf[0]
401 if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
402 return si, 0, CorruptInputError(si - 2)
403 }
404 }
405
406 return si, dlen - 1, err
407 }
408
409
410
411
412 func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
413
414 n := len(src)
415 for n > 0 && rune(src[n-1]) == enc.padChar {
416 n--
417 }
418 n = decodedLen(n, NoPadding)
419
420 dst = slices.Grow(dst, n)
421 n, err := enc.Decode(dst[len(dst):][:n], src)
422 return dst[:len(dst)+n], err
423 }
424
425
426 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
427 dbuf := make([]byte, enc.DecodedLen(len(s)))
428 n, err := enc.Decode(dbuf, []byte(s))
429 return dbuf[:n], err
430 }
431
432 type decoder struct {
433 err error
434 readErr error
435 enc *Encoding
436 r io.Reader
437 buf [1024]byte
438 nbuf int
439 out []byte
440 outbuf [1024 / 4 * 3]byte
441 }
442
443 func (d *decoder) Read(p []byte) (n int, err error) {
444
445 if len(d.out) > 0 {
446 n = copy(p, d.out)
447 d.out = d.out[n:]
448 return n, nil
449 }
450
451 if d.err != nil {
452 return 0, d.err
453 }
454
455
456
457
458 for d.nbuf < 4 && d.readErr == nil {
459 nn := len(p) / 3 * 4
460 if nn < 4 {
461 nn = 4
462 }
463 if nn > len(d.buf) {
464 nn = len(d.buf)
465 }
466 nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
467 d.nbuf += nn
468 }
469
470 if d.nbuf < 4 {
471 if d.enc.padChar == NoPadding && d.nbuf > 0 {
472
473 var nw int
474 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
475 d.nbuf = 0
476 d.out = d.outbuf[:nw]
477 n = copy(p, d.out)
478 d.out = d.out[n:]
479 if n > 0 || len(p) == 0 && len(d.out) > 0 {
480 return n, nil
481 }
482 if d.err != nil {
483 return 0, d.err
484 }
485 }
486 d.err = d.readErr
487 if d.err == io.EOF && d.nbuf > 0 {
488 d.err = io.ErrUnexpectedEOF
489 }
490 return 0, d.err
491 }
492
493
494 nr := d.nbuf / 4 * 4
495 nw := d.nbuf / 4 * 3
496 if nw > len(p) {
497 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
498 d.out = d.outbuf[:nw]
499 n = copy(p, d.out)
500 d.out = d.out[n:]
501 } else {
502 n, d.err = d.enc.Decode(p, d.buf[:nr])
503 }
504 d.nbuf -= nr
505 copy(d.buf[:d.nbuf], d.buf[nr:])
506 return n, d.err
507 }
508
509
510
511
512
513
514 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
515 if len(src) == 0 {
516 return 0, nil
517 }
518
519
520
521
522 _ = enc.decodeMap
523
524 si := 0
525 for strconv.IntSize >= 64 && len(src)-si >= 8 && len(dst)-n >= 8 {
526 src2 := src[si : si+8]
527 if dn, ok := assemble64(
528 enc.decodeMap[src2[0]],
529 enc.decodeMap[src2[1]],
530 enc.decodeMap[src2[2]],
531 enc.decodeMap[src2[3]],
532 enc.decodeMap[src2[4]],
533 enc.decodeMap[src2[5]],
534 enc.decodeMap[src2[6]],
535 enc.decodeMap[src2[7]],
536 ); ok {
537 binary.BigEndian.PutUint64(dst[n:], dn)
538 n += 6
539 si += 8
540 } else {
541 var ninc int
542 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
543 n += ninc
544 if err != nil {
545 return n, err
546 }
547 }
548 }
549
550 for len(src)-si >= 4 && len(dst)-n >= 4 {
551 src2 := src[si : si+4]
552 if dn, ok := assemble32(
553 enc.decodeMap[src2[0]],
554 enc.decodeMap[src2[1]],
555 enc.decodeMap[src2[2]],
556 enc.decodeMap[src2[3]],
557 ); ok {
558 binary.BigEndian.PutUint32(dst[n:], dn)
559 n += 3
560 si += 4
561 } else {
562 var ninc int
563 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
564 n += ninc
565 if err != nil {
566 return n, err
567 }
568 }
569 }
570
571 for si < len(src) {
572 var ninc int
573 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
574 n += ninc
575 if err != nil {
576 return n, err
577 }
578 }
579 return n, err
580 }
581
582
583
584
// assemble32 packs four 6-bit values into the upper 24 bits of a uint32,
// first value highest; the low 8 bits are zero. It returns ok == false,
// with dn == 0, when the bitwise OR of the inputs equals 0xff, which is
// always the case when any input is the 0xff "invalid" marker.
func assemble32(n1, n2, n3, n4 byte) (dn uint32, ok bool) {
	// One combined comparison replaces four separate ones: valid 6-bit
	// values can never OR together to 0xff.
	if merged := n1 | n2 | n3 | n4; merged == 0xff {
		return 0, false
	}
	dn = uint32(n1)<<26 | uint32(n2)<<20 | uint32(n3)<<14 | uint32(n4)<<8
	return dn, true
}
597
598
599
600
// assemble64 packs eight 6-bit values into the upper 48 bits of a uint64,
// first value highest; the low 16 bits are zero. It returns ok == false,
// with dn == 0, when the bitwise OR of the inputs equals 0xff, which is
// always the case when any input is the 0xff "invalid" marker.
func assemble64(n1, n2, n3, n4, n5, n6, n7, n8 byte) (dn uint64, ok bool) {
	// One combined comparison replaces eight separate ones: valid 6-bit
	// values can never OR together to 0xff.
	if merged := n1 | n2 | n3 | n4 | n5 | n6 | n7 | n8; merged == 0xff {
		return 0, false
	}
	hi := uint64(n1)<<58 | uint64(n2)<<52 | uint64(n3)<<46 | uint64(n4)<<40
	lo := uint64(n5)<<34 | uint64(n6)<<28 | uint64(n7)<<22 | uint64(n8)<<16
	return hi | lo, true
}
617
// newlineFilteringReader wraps an io.Reader and removes every '\r' and '\n'
// from the data it passes along.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the result in
// place, dropping '\r' and '\n'. If a chunk consists of nothing but
// newlines it reads again, so a return of zero bytes only happens when the
// wrapped reader itself returns zero bytes. The wrapped reader's error is
// returned alongside the filtered data of the same call.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}

		// Move the surviving bytes to the front of p.
		kept := 0
		for _, c := range p[:n] {
			if c == '\r' || c == '\n' {
				continue
			}
			p[kept] = c
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// The chunk was entirely newlines; fetch another one.
	}
}
642
643
644 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
645 return &decoder{enc: enc, r: &newlineFilteringReader{r}}
646 }
647
648
649
650 func (enc *Encoding) DecodedLen(n int) int {
651 return decodedLen(n, enc.padChar)
652 }
653
654 func decodedLen(n int, padChar rune) int {
655 if padChar == NoPadding {
656
657 return n/4*3 + n%4*6/8
658 }
659
660 return n / 4 * 3
661 }
662
View as plain text