1
2
3
4
5 package asn1
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "math/big"
12 "reflect"
13 "slices"
14 "time"
15 "unicode/utf8"
16 )
17
// encoder is implemented by every piece of output this package produces.
// Len reports the exact size in bytes and Encode writes exactly that many
// bytes, which lets Marshal size its output buffer once before encoding.
type encoder interface {
	// Len returns the number of bytes Encode will write.
	Len() int
	// Encode writes the encoding into the start of dst, which must be at
	// least Len() bytes long.
	Encode(dst []byte)
}

// Shared single-byte encoders: BOOLEAN contents, the zero INTEGER, and the
// sign-padding byte placed in front of INTEGER contents.
var (
	byte00Encoder encoder = byteEncoder(0x00)
	byteFFEncoder encoder = byteEncoder(0xff)
)

// byteEncoder encodes as exactly one byte: its own value.
type byteEncoder byte

// Len always reports 1.
func (c byteEncoder) Len() int { return 1 }

// Encode stores the byte in dst[0].
func (c byteEncoder) Encode(dst []byte) { dst[0] = byte(c) }
40
// bytesEncoder encodes as its raw contents, unchanged.
type bytesEncoder []byte

// Len reports how many bytes are held.
func (b bytesEncoder) Len() int { return len(b) }

// Encode copies the bytes into dst. A destination shorter than Len() means
// the caller sized its buffer wrongly, so Encode panics.
func (b bytesEncoder) Encode(dst []byte) {
	if n := copy(dst, b); n < len(b) {
		panic("internal error")
	}
}
52
// stringEncoder encodes as the bytes of its string, unchanged.
type stringEncoder string

// Len reports the length of the string in bytes.
func (s stringEncoder) Len() int { return len(s) }

// Encode copies the string's bytes into dst. A destination shorter than
// Len() means the caller sized its buffer wrongly, so Encode panics.
func (s stringEncoder) Encode(dst []byte) {
	if n := copy(dst, s); n < len(s) {
		panic("internal error")
	}
}
64
65 type multiEncoder []encoder
66
67 func (m multiEncoder) Len() int {
68 var size int
69 for _, e := range m {
70 size += e.Len()
71 }
72 return size
73 }
74
75 func (m multiEncoder) Encode(dst []byte) {
76 var off int
77 for _, e := range m {
78 e.Encode(dst[off:])
79 off += e.Len()
80 }
81 }
82
83 type setEncoder []encoder
84
85 func (s setEncoder) Len() int {
86 var size int
87 for _, e := range s {
88 size += e.Len()
89 }
90 return size
91 }
92
93 func (s setEncoder) Encode(dst []byte) {
94
95
96
97
98
99
100
101
102 l := make([][]byte, len(s))
103 for i, e := range s {
104 l[i] = make([]byte, e.Len())
105 e.Encode(l[i])
106 }
107
108
109
110
111
112
113
114 slices.SortFunc(l, bytes.Compare)
115
116 var off int
117 for _, b := range l {
118 copy(dst[off:], b)
119 off += len(b)
120 }
121 }
122
123 type taggedEncoder struct {
124
125
126 scratch [8]byte
127 tag encoder
128 body encoder
129 }
130
131 func (t *taggedEncoder) Len() int {
132 return t.tag.Len() + t.body.Len()
133 }
134
135 func (t *taggedEncoder) Encode(dst []byte) {
136 t.tag.Encode(dst)
137 t.body.Encode(dst[t.tag.Len():])
138 }
139
// int64Encoder encodes INTEGER contents: the shortest big-endian
// two's-complement representation that preserves the value's sign.
type int64Encoder int64

// Len counts the bytes needed. Each extra byte is required while the value
// does not fit in a signed byte (-128..127).
func (i int64Encoder) Len() int {
	n := 1
	for v := int64(i); v > 127 || v < -128; v >>= 8 {
		n++
	}
	return n
}

// Encode writes the Len() lowest-order bytes of i, most significant first.
func (i int64Encoder) Encode(dst []byte) {
	n := i.Len()
	for j, shift := 0, 8*(n-1); j < n; j, shift = j+1, shift-8 {
		dst[j] = byte(i >> uint(shift))
	}
}
165
// base128IntLength reports how many 7-bit groups are needed to write n in
// base 128. Zero needs one group. A negative n yields 0, so nothing would
// be written for it.
func base128IntLength(n int64) int {
	if n == 0 {
		return 1
	}
	groups := 0
	for ; n > 0; n >>= 7 {
		groups++
	}
	return groups
}

// appendBase128Int appends n to dst as big-endian 7-bit groups. Every group
// except the last has its high bit set to mark that more groups follow.
func appendBase128Int(dst []byte, n int64) []byte {
	for i := base128IntLength(n) - 1; i >= 0; i-- {
		group := byte(n>>uint(7*i)) & 0x7f
		if i > 0 {
			group |= 0x80 // continuation bit
		}
		dst = append(dst, group)
	}
	return dst
}
194
195 func makeBigInt(n *big.Int) (encoder, error) {
196 if n == nil {
197 return nil, StructuralError{"empty integer"}
198 }
199
200 if n.Sign() < 0 {
201
202
203
204
205 nMinus1 := new(big.Int).Neg(n)
206 nMinus1.Sub(nMinus1, bigOne)
207 bytes := nMinus1.Bytes()
208 for i := range bytes {
209 bytes[i] ^= 0xff
210 }
211 if len(bytes) == 0 || bytes[0]&0x80 == 0 {
212 return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
213 }
214 return bytesEncoder(bytes), nil
215 } else if n.Sign() == 0 {
216
217 return byte00Encoder, nil
218 } else {
219 bytes := n.Bytes()
220 if len(bytes) > 0 && bytes[0]&0x80 != 0 {
221
222
223 return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
224 }
225 return bytesEncoder(bytes), nil
226 }
227 }
228
// appendLength appends i to dst as a big-endian number of exactly
// lengthLength(i) bytes. This is the long-form length octets, without the
// leading count byte.
func appendLength(dst []byte, i int) []byte {
	for shift := 8 * (lengthLength(i) - 1); shift >= 0; shift -= 8 {
		dst = append(dst, byte(i>>uint(shift)))
	}
	return dst
}

// lengthLength reports how many bytes appendLength uses for i. It is at
// least 1, plus one more for each further byte of magnitude above 255.
func lengthLength(i int) int {
	n := 1
	for ; i > 255; i >>= 8 {
		n++
	}
	return n
}
247
248 func appendTagAndLength(dst []byte, t tagAndLength) []byte {
249 b := uint8(t.class) << 6
250 if t.isCompound {
251 b |= 0x20
252 }
253 if t.tag >= 31 {
254 b |= 0x1f
255 dst = append(dst, b)
256 dst = appendBase128Int(dst, int64(t.tag))
257 } else {
258 b |= uint8(t.tag)
259 dst = append(dst, b)
260 }
261
262 if t.length >= 128 {
263 l := lengthLength(t.length)
264 dst = append(dst, 0x80|byte(l))
265 dst = appendLength(dst, t.length)
266 } else {
267 dst = append(dst, byte(t.length))
268 }
269
270 return dst
271 }
272
273 type bitStringEncoder BitString
274
275 func (b bitStringEncoder) Len() int {
276 return len(b.Bytes) + 1
277 }
278
279 func (b bitStringEncoder) Encode(dst []byte) {
280 dst[0] = byte((8 - b.BitLength%8) % 8)
281 if copy(dst[1:], b.Bytes) != len(b.Bytes) {
282 panic("internal error")
283 }
284 }
285
286 type oidEncoder []int
287
288 func (oid oidEncoder) Len() int {
289 l := base128IntLength(int64(oid[0]*40 + oid[1]))
290 for i := 2; i < len(oid); i++ {
291 l += base128IntLength(int64(oid[i]))
292 }
293 return l
294 }
295
296 func (oid oidEncoder) Encode(dst []byte) {
297 dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
298 for i := 2; i < len(oid); i++ {
299 dst = appendBase128Int(dst, int64(oid[i]))
300 }
301 }
302
303 func makeObjectIdentifier(oid []int) (e encoder, err error) {
304 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
305 return nil, StructuralError{"invalid object identifier"}
306 }
307
308 return oidEncoder(oid), nil
309 }
310
311 func makePrintableString(s string) (e encoder, err error) {
312 for i := 0; i < len(s); i++ {
313
314
315
316
317
318
319 if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
320 return nil, StructuralError{"PrintableString contains invalid character"}
321 }
322 }
323
324 return stringEncoder(s), nil
325 }
326
327 func makeIA5String(s string) (e encoder, err error) {
328 for i := 0; i < len(s); i++ {
329 if s[i] > 127 {
330 return nil, StructuralError{"IA5String contains invalid character"}
331 }
332 }
333
334 return stringEncoder(s), nil
335 }
336
337 func makeNumericString(s string) (e encoder, err error) {
338 for i := 0; i < len(s); i++ {
339 if !isNumeric(s[i]) {
340 return nil, StructuralError{"NumericString contains invalid character"}
341 }
342 }
343
344 return stringEncoder(s), nil
345 }
346
347 func makeUTF8String(s string) encoder {
348 return stringEncoder(s)
349 }
350
// appendTwoDigits appends the last two decimal digits of v, zero-padded.
func appendTwoDigits(dst []byte, v int) []byte {
	tens, ones := (v/10)%10, v%10
	return append(dst, '0'+byte(tens), '0'+byte(ones))
}

// appendFourDigits appends the last four decimal digits of v, zero-padded.
func appendFourDigits(dst []byte, v int) []byte {
	for div := 1000; div > 0; div /= 10 {
		dst = append(dst, byte('0'+(v/div)%10))
	}
	return dst
}
362
// outsideUTCRange reports whether t's year cannot be written as a UTCTime.
// UTCTime's two-digit year covers only 1950 through 2049.
func outsideUTCRange(t time.Time) bool {
	y := t.Year()
	return !(1950 <= y && y < 2050)
}
367
368 func makeUTCTime(t time.Time) (e encoder, err error) {
369 dst := make([]byte, 0, 18)
370
371 dst, err = appendUTCTime(dst, t)
372 if err != nil {
373 return nil, err
374 }
375
376 return bytesEncoder(dst), nil
377 }
378
379 func makeGeneralizedTime(t time.Time) (e encoder, err error) {
380 dst := make([]byte, 0, 20)
381
382 dst, err = appendGeneralizedTime(dst, t)
383 if err != nil {
384 return nil, err
385 }
386
387 return bytesEncoder(dst), nil
388 }
389
390 func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
391 year := t.Year()
392
393 switch {
394 case 1950 <= year && year < 2000:
395 dst = appendTwoDigits(dst, year-1900)
396 case 2000 <= year && year < 2050:
397 dst = appendTwoDigits(dst, year-2000)
398 default:
399 return nil, StructuralError{"cannot represent time as UTCTime"}
400 }
401
402 return appendTimeCommon(dst, t), nil
403 }
404
405 func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
406 year := t.Year()
407 if year < 0 || year > 9999 {
408 return nil, StructuralError{"cannot represent time as GeneralizedTime"}
409 }
410
411 dst = appendFourDigits(dst, year)
412
413 return appendTimeCommon(dst, t), nil
414 }
415
416 func appendTimeCommon(dst []byte, t time.Time) []byte {
417 _, month, day := t.Date()
418
419 dst = appendTwoDigits(dst, int(month))
420 dst = appendTwoDigits(dst, day)
421
422 hour, min, sec := t.Clock()
423
424 dst = appendTwoDigits(dst, hour)
425 dst = appendTwoDigits(dst, min)
426 dst = appendTwoDigits(dst, sec)
427
428 _, offset := t.Zone()
429
430 switch {
431 case offset/60 == 0:
432 return append(dst, 'Z')
433 case offset > 0:
434 dst = append(dst, '+')
435 case offset < 0:
436 dst = append(dst, '-')
437 }
438
439 offsetMinutes := offset / 60
440 if offsetMinutes < 0 {
441 offsetMinutes = -offsetMinutes
442 }
443
444 dst = appendTwoDigits(dst, offsetMinutes/60)
445 dst = appendTwoDigits(dst, offsetMinutes%60)
446
447 return dst
448 }
449
450 func stripTagAndLength(in []byte) []byte {
451 _, offset, err := parseTagAndLength(in, 0)
452 if err != nil {
453 return in
454 }
455 return in[offset:]
456 }
457
458 func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
459 switch value.Type() {
460 case flagType:
461 return bytesEncoder(nil), nil
462 case timeType:
463 t := value.Interface().(time.Time)
464 if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
465 return makeGeneralizedTime(t)
466 }
467 return makeUTCTime(t)
468 case bitStringType:
469 return bitStringEncoder(value.Interface().(BitString)), nil
470 case objectIdentifierType:
471 return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
472 case bigIntType:
473 return makeBigInt(value.Interface().(*big.Int))
474 }
475
476 switch v := value; v.Kind() {
477 case reflect.Bool:
478 if v.Bool() {
479 return byteFFEncoder, nil
480 }
481 return byte00Encoder, nil
482 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
483 return int64Encoder(v.Int()), nil
484 case reflect.Struct:
485 t := v.Type()
486
487 for i := 0; i < t.NumField(); i++ {
488 if !t.Field(i).IsExported() {
489 return nil, StructuralError{"struct contains unexported fields"}
490 }
491 }
492
493 startingField := 0
494
495 n := t.NumField()
496 if n == 0 {
497 return bytesEncoder(nil), nil
498 }
499
500
501
502 if t.Field(0).Type == rawContentsType {
503 s := v.Field(0)
504 if s.Len() > 0 {
505 bytes := s.Bytes()
506
510 return bytesEncoder(stripTagAndLength(bytes)), nil
511 }
512
513 startingField = 1
514 }
515
516 switch n1 := n - startingField; n1 {
517 case 0:
518 return bytesEncoder(nil), nil
519 case 1:
520 return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
521 default:
522 m := make([]encoder, n1)
523 for i := 0; i < n1; i++ {
524 m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
525 if err != nil {
526 return nil, err
527 }
528 }
529
530 return multiEncoder(m), nil
531 }
532 case reflect.Slice:
533 sliceType := v.Type()
534 if sliceType.Elem().Kind() == reflect.Uint8 {
535 return bytesEncoder(v.Bytes()), nil
536 }
537
538 var fp fieldParameters
539
540 switch l := v.Len(); l {
541 case 0:
542 return bytesEncoder(nil), nil
543 case 1:
544 return makeField(v.Index(0), fp)
545 default:
546 m := make([]encoder, l)
547
548 for i := 0; i < l; i++ {
549 m[i], err = makeField(v.Index(i), fp)
550 if err != nil {
551 return nil, err
552 }
553 }
554
555 if params.set {
556 return setEncoder(m), nil
557 }
558 return multiEncoder(m), nil
559 }
560 case reflect.String:
561 switch params.stringType {
562 case TagIA5String:
563 return makeIA5String(v.String())
564 case TagPrintableString:
565 return makePrintableString(v.String())
566 case TagNumericString:
567 return makeNumericString(v.String())
568 default:
569 return makeUTF8String(v.String()), nil
570 }
571 }
572
573 return nil, StructuralError{"unknown Go type"}
574 }
575
// makeField returns an encoder for the complete TLV (tag, length and
// contents) of v, as controlled by params.
//
// The returned encoder writes nothing when the field is omitted. That
// happens for an empty slice with omitEmpty set, for an optional field equal
// to its explicit default value (only for kinds where canHaveDefaultValue
// allows a default), and, when no default is given, for an optional field
// equal to the zero value of its type.
//
// A RawValue is emitted from its stored bytes. For any other value, the
// universal tag comes from getUniversalType. That tag is adjusted for
// PrintableString/UTF8String inference, UTCTime/GeneralizedTime selection
// and SET handling. Finally it is either replaced (implicit tagging) or
// wrapped (explicit tagging) by the tag in params, if one is given.
//
// Errors are returned for an invalid (e.g. nil) value, for Go types
// getUniversalType does not accept, for time or string overrides applied to
// non-time or non-string fields, for "set" on a non-sequence, for invalid
// UTF-8 in a string whose type is inferred, and for any error from makeBody.
func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
	if !v.IsValid() {
		return nil, fmt.Errorf("asn1: cannot marshal nil value")
	}
	// An empty-interface value is marshaled as its dynamic value.
	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
		return makeField(v.Elem(), params)
	}

	// omitEmpty: an empty slice produces no output at all.
	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
		return bytesEncoder(nil), nil
	}

	// An optional field equal to its explicit default value is omitted.
	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
		defaultValue := reflect.New(v.Type()).Elem()
		defaultValue.SetInt(*params.defaultValue)

		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// With no explicit default, an optional field is omitted when it equals
	// the zero value of its type.
	if params.optional && params.defaultValue == nil {
		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// A RawValue is emitted as-is. FullBytes is written verbatim when
	// present; otherwise a header is built from Class, Tag, IsCompound and
	// len(Bytes), followed by Bytes.
	if v.Type() == rawValueType {
		rv := v.Interface().(RawValue)
		if len(rv.FullBytes) != 0 {
			return bytesEncoder(rv.FullBytes), nil
		}

		t := new(taggedEncoder)

		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
		t.body = bytesEncoder(rv.Bytes)

		return t, nil
	}

	matchAny, tag, isCompound, ok := getUniversalType(v.Type())
	if !ok || matchAny {
		return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
	}

	// Time and string overrides are only accepted when the universal tag is
	// TagUTCTime or TagPrintableString respectively.
	if params.timeType != 0 && tag != TagUTCTime {
		return nil, StructuralError{"explicit time type given to non-time member"}
	}

	if params.stringType != 0 && tag != TagPrintableString {
		return nil, StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case TagPrintableString:
		if params.stringType == 0 {
			// No explicit string type. Keep PrintableString if every
			// character passes isPrintable(rejectAsterisk, rejectAmpersand);
			// otherwise fall back to UTF8String, which requires valid UTF-8.
			for _, r := range v.String() {
				if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
					if !utf8.ValidString(v.String()) {
						return nil, errors.New("asn1: string not valid UTF-8")
					}
					tag = TagUTF8String
					// This break leaves the range loop, not the switch.
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case TagUTCTime:
		// Use GeneralizedTime when it was requested or when the year is
		// outside UTCTime's range.
		if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
			tag = TagGeneralizedTime
		}
	}

	// The "set" parameter turns a SEQUENCE into a SET.
	if params.set {
		if tag != TagSequence {
			return nil, StructuralError{"non sequence tagged as set"}
		}
		tag = TagSet
	}

	// getUniversalType may itself report TagSet while params.set is false.
	// makeBody only sorts slice elements as a SET when params.set is true,
	// so set it here to keep the body consistent with the tag.
	if tag == TagSet && !params.set {
		params.set = true
	}

	t := new(taggedEncoder)

	t.body, err = makeBody(v, params)
	if err != nil {
		return nil, err
	}

	bodyLen := t.body.Len()

	// With no tag in params, the universal class is used. A tag in params
	// selects the application class, the private class, or (otherwise) the
	// context-specific class.
	class := ClassUniversal
	if params.tag != nil {
		if params.application {
			class = ClassApplication
		} else if params.private {
			class = ClassPrivate
		} else {
			class = ClassContextSpecific
		}

		if params.explicit {
			// Explicit tagging: build the complete universal TLV, then wrap
			// it in a constructed outer TLV that carries the requested class
			// and tag.
			t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))

			tt := new(taggedEncoder)

			tt.body = t

			tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
				class:      class,
				tag:        *params.tag,
				length:     bodyLen + t.tag.Len(),
				isCompound: true,
			}))

			return tt, nil
		}

		// Implicit tagging: the requested tag replaces the universal tag.
		tag = *params.tag
	}

	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))

	return t, nil
}
718
719
720
721
722
723
724
725
726
727
728
729
730 func Marshal(val any) ([]byte, error) {
731 return MarshalWithParams(val, "")
732 }
733
734
735
736 func MarshalWithParams(val any, params string) ([]byte, error) {
737 e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
738 if err != nil {
739 return nil, err
740 }
741 b := make([]byte, e.Len())
742 e.Encode(b)
743 return b, nil
744 }
745
View as plain text