1
2
3
4
5 package xml
6
7 import (
8 "bytes"
9 "encoding"
10 "errors"
11 "fmt"
12 "reflect"
13 "runtime"
14 "strconv"
15 "strings"
16 )
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133 func Unmarshal(data []byte, v any) error {
134 return NewDecoder(bytes.NewReader(data)).Decode(v)
135 }
136
137
138
139 func (d *Decoder) Decode(v any) error {
140 return d.DecodeElement(v, nil)
141 }
142
143
144
145
146
147 func (d *Decoder) DecodeElement(v any, start *StartElement) error {
148 val := reflect.ValueOf(v)
149 if val.Kind() != reflect.Pointer {
150 return errors.New("non-pointer passed to Unmarshal")
151 }
152
153 if val.IsNil() {
154 return errors.New("nil pointer passed to Unmarshal")
155 }
156 return d.unmarshal(val.Elem(), start, 0)
157 }
158
159
// An UnmarshalError is an error message produced while unmarshalling,
// for example when the element name or name space found in the input
// does not match what the destination struct requires.
type UnmarshalError string

// Error returns the message text, making UnmarshalError satisfy error.
func (e UnmarshalError) Error() string {
	msg := string(e)
	return msg
}
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
// Unmarshaler is implemented by types that decode an XML element
// description of themselves.
//
// When the decoder finds that a destination value (or a pointer to it)
// implements Unmarshaler, it calls UnmarshalXML with itself and a copy of
// the element's start token instead of applying its own decoding rules.
// A non-nil error from UnmarshalXML is returned to the caller. If
// UnmarshalXML returns nil without having consumed the entire element,
// the decoder reports an error.
type Unmarshaler interface {
	UnmarshalXML(d *Decoder, start StartElement) error
}
182
183
184
185
186
187
188
189
190
// UnmarshalerAttr is implemented by types that decode a single XML
// attribute into themselves.
//
// When an attribute's destination value (or a pointer to it) implements
// UnmarshalerAttr, the decoder passes the attribute to UnmarshalXMLAttr
// and returns the method's result as the outcome of decoding it.
type UnmarshalerAttr interface {
	UnmarshalXMLAttr(attr Attr) error
}
194
195
// receiverType returns the type of val as it should appear before a
// method name in a message: the plain type string for named types, and
// the type string wrapped in parentheses for unnamed ones such as
// pointers (for example "(*pkg.T)").
func receiverType(val any) string {
	typ := reflect.TypeOf(val)
	desc := typ.String()
	if typ.Name() == "" {
		return "(" + desc + ")"
	}
	return desc
}
203
204
205
206 func (d *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error {
207
208 d.pushEOF()
209
210 d.unmarshalDepth++
211 err := val.UnmarshalXML(d, *start)
212 d.unmarshalDepth--
213 if err != nil {
214 d.popEOF()
215 return err
216 }
217
218 if !d.popEOF() {
219 return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local)
220 }
221
222 return nil
223 }
224
225
226
227
228 func (d *Decoder) unmarshalTextInterface(val encoding.TextUnmarshaler) error {
229 var buf []byte
230 depth := 1
231 for depth > 0 {
232 t, err := d.Token()
233 if err != nil {
234 return err
235 }
236 switch t := t.(type) {
237 case CharData:
238 if depth == 1 {
239 buf = append(buf, t...)
240 }
241 case StartElement:
242 depth++
243 case EndElement:
244 depth--
245 }
246 }
247 return val.UnmarshalText(buf)
248 }
249
250
251 func (d *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error {
252 if val.Kind() == reflect.Pointer {
253 if val.IsNil() {
254 val.Set(reflect.New(val.Type().Elem()))
255 }
256 val = val.Elem()
257 }
258 if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) {
259
260
261 return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
262 }
263 if val.CanAddr() {
264 pv := val.Addr()
265 if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) {
266 return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr)
267 }
268 }
269
270
271 if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
272
273
274 return val.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
275 }
276 if val.CanAddr() {
277 pv := val.Addr()
278 if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
279 return pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(attr.Value))
280 }
281 }
282
283 if val.Type().Kind() == reflect.Slice && val.Type().Elem().Kind() != reflect.Uint8 {
284
285
286 n := val.Len()
287 val.Grow(1)
288 val.SetLen(n + 1)
289
290
291 if err := d.unmarshalAttr(val.Index(n), attr); err != nil {
292 val.SetLen(n)
293 return err
294 }
295 return nil
296 }
297
298 if val.Type() == attrType {
299 val.Set(reflect.ValueOf(attr))
300 return nil
301 }
302
303 return copyValue(val, []byte(attr.Value))
304 }
305
306 var (
307 attrType = reflect.TypeFor[Attr]()
308 unmarshalerType = reflect.TypeFor[Unmarshaler]()
309 unmarshalerAttrType = reflect.TypeFor[UnmarshalerAttr]()
310 textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
311 )
312
// maxUnmarshalDepth is the nesting depth at which unmarshal gives up.
const maxUnmarshalDepth = 10000

// maxUnmarshalDepthWasm is the lower nesting limit that unmarshal
// applies when runtime.GOARCH is "wasm".
const maxUnmarshalDepthWasm = 5000
317
// errUnmarshalDepth is returned by unmarshal when the nesting depth
// reaches the configured maximum.
var errUnmarshalDepth error = errors.New("exceeded max depth")
319
320
321 func (d *Decoder) unmarshal(val reflect.Value, start *StartElement, depth int) error {
322 if depth >= maxUnmarshalDepth || runtime.GOARCH == "wasm" && depth >= maxUnmarshalDepthWasm {
323 return errUnmarshalDepth
324 }
325
326 if start == nil {
327 for {
328 tok, err := d.Token()
329 if err != nil {
330 return err
331 }
332 if t, ok := tok.(StartElement); ok {
333 start = &t
334 break
335 }
336 }
337 }
338
339
340
341 if val.Kind() == reflect.Interface && !val.IsNil() {
342 e := val.Elem()
343 if e.Kind() == reflect.Pointer && !e.IsNil() {
344 val = e
345 }
346 }
347
348 if val.Kind() == reflect.Pointer {
349 if val.IsNil() {
350 val.Set(reflect.New(val.Type().Elem()))
351 }
352 val = val.Elem()
353 }
354
355 if val.CanInterface() && val.Type().Implements(unmarshalerType) {
356
357
358 return d.unmarshalInterface(val.Interface().(Unmarshaler), start)
359 }
360
361 if val.CanAddr() {
362 pv := val.Addr()
363 if pv.CanInterface() && pv.Type().Implements(unmarshalerType) {
364 return d.unmarshalInterface(pv.Interface().(Unmarshaler), start)
365 }
366 }
367
368 if val.CanInterface() && val.Type().Implements(textUnmarshalerType) {
369 return d.unmarshalTextInterface(val.Interface().(encoding.TextUnmarshaler))
370 }
371
372 if val.CanAddr() {
373 pv := val.Addr()
374 if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
375 return d.unmarshalTextInterface(pv.Interface().(encoding.TextUnmarshaler))
376 }
377 }
378
379 var (
380 data []byte
381 saveData reflect.Value
382 comment []byte
383 saveComment reflect.Value
384 saveXML reflect.Value
385 saveXMLIndex int
386 saveXMLData []byte
387 saveAny reflect.Value
388 sv reflect.Value
389 tinfo *typeInfo
390 err error
391 )
392
393 switch v := val; v.Kind() {
394 default:
395 return errors.New("unknown type " + v.Type().String())
396
397 case reflect.Interface:
398
399
400
401 return d.Skip()
402
403 case reflect.Slice:
404 typ := v.Type()
405 if typ.Elem().Kind() == reflect.Uint8 {
406
407 saveData = v
408 break
409 }
410
411
412
413 n := v.Len()
414 v.Grow(1)
415 v.SetLen(n + 1)
416
417
418 if err := d.unmarshal(v.Index(n), start, depth+1); err != nil {
419 v.SetLen(n)
420 return err
421 }
422 return nil
423
424 case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr, reflect.String:
425 saveData = v
426
427 case reflect.Struct:
428 typ := v.Type()
429 if typ == nameType {
430 v.Set(reflect.ValueOf(start.Name))
431 break
432 }
433
434 sv = v
435 tinfo, err = getTypeInfo(typ)
436 if err != nil {
437 return err
438 }
439
440
441 if tinfo.xmlname != nil {
442 finfo := tinfo.xmlname
443 if finfo.name != "" && finfo.name != start.Name.Local {
444 return UnmarshalError("expected element type <" + finfo.name + "> but have <" + start.Name.Local + ">")
445 }
446 if finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
447 e := "expected element <" + finfo.name + "> in name space " + finfo.xmlns + " but have "
448 if start.Name.Space == "" {
449 e += "no name space"
450 } else {
451 e += start.Name.Space
452 }
453 return UnmarshalError(e)
454 }
455 fv := finfo.value(sv, initNilPointers)
456 if _, ok := fv.Interface().(Name); ok {
457 fv.Set(reflect.ValueOf(start.Name))
458 }
459 }
460
461
462 for _, a := range start.Attr {
463 handled := false
464 any := -1
465 for i := range tinfo.fields {
466 finfo := &tinfo.fields[i]
467 switch finfo.flags & fMode {
468 case fAttr:
469 strv := finfo.value(sv, initNilPointers)
470 if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
471 if err := d.unmarshalAttr(strv, a); err != nil {
472 return err
473 }
474 handled = true
475 }
476
477 case fAny | fAttr:
478 if any == -1 {
479 any = i
480 }
481 }
482 }
483 if !handled && any >= 0 {
484 finfo := &tinfo.fields[any]
485 strv := finfo.value(sv, initNilPointers)
486 if err := d.unmarshalAttr(strv, a); err != nil {
487 return err
488 }
489 }
490 }
491
492
493 for i := range tinfo.fields {
494 finfo := &tinfo.fields[i]
495 switch finfo.flags & fMode {
496 case fCDATA, fCharData:
497 if !saveData.IsValid() {
498 saveData = finfo.value(sv, initNilPointers)
499 }
500
501 case fComment:
502 if !saveComment.IsValid() {
503 saveComment = finfo.value(sv, initNilPointers)
504 }
505
506 case fAny, fAny | fElement:
507 if !saveAny.IsValid() {
508 saveAny = finfo.value(sv, initNilPointers)
509 }
510
511 case fInnerXML:
512 if !saveXML.IsValid() {
513 saveXML = finfo.value(sv, initNilPointers)
514 if d.saved == nil {
515 saveXMLIndex = 0
516 d.saved = new(bytes.Buffer)
517 } else {
518 saveXMLIndex = d.savedOffset()
519 }
520 }
521 }
522 }
523 }
524
525
526
527 Loop:
528 for {
529 var savedOffset int
530 if saveXML.IsValid() {
531 savedOffset = d.savedOffset()
532 }
533 tok, err := d.Token()
534 if err != nil {
535 return err
536 }
537 switch t := tok.(type) {
538 case StartElement:
539 consumed := false
540 if sv.IsValid() {
541
542
543 consumed, err = d.unmarshalPath(tinfo, sv, nil, &t, depth)
544 if err != nil {
545 return err
546 }
547 if !consumed && saveAny.IsValid() {
548 consumed = true
549 if err := d.unmarshal(saveAny, &t, depth+1); err != nil {
550 return err
551 }
552 }
553 }
554 if !consumed {
555 if err := d.Skip(); err != nil {
556 return err
557 }
558 }
559
560 case EndElement:
561 if saveXML.IsValid() {
562 saveXMLData = d.saved.Bytes()[saveXMLIndex:savedOffset]
563 if saveXMLIndex == 0 {
564 d.saved = nil
565 }
566 }
567 break Loop
568
569 case CharData:
570 if saveData.IsValid() {
571 data = append(data, t...)
572 }
573
574 case Comment:
575 if saveComment.IsValid() {
576 comment = append(comment, t...)
577 }
578 }
579 }
580
581 if saveData.IsValid() && saveData.CanInterface() && saveData.Type().Implements(textUnmarshalerType) {
582 if err := saveData.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
583 return err
584 }
585 saveData = reflect.Value{}
586 }
587
588 if saveData.IsValid() && saveData.CanAddr() {
589 pv := saveData.Addr()
590 if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
591 if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(data); err != nil {
592 return err
593 }
594 saveData = reflect.Value{}
595 }
596 }
597
598 if err := copyValue(saveData, data); err != nil {
599 return err
600 }
601
602 switch t := saveComment; t.Kind() {
603 case reflect.String:
604 t.SetString(string(comment))
605 case reflect.Slice:
606 t.Set(reflect.ValueOf(comment))
607 }
608
609 switch t := saveXML; t.Kind() {
610 case reflect.String:
611 t.SetString(string(saveXMLData))
612 case reflect.Slice:
613 if t.Type().Elem().Kind() == reflect.Uint8 {
614 t.Set(reflect.ValueOf(saveXMLData))
615 }
616 }
617
618 return nil
619 }
620
// copyValue converts the text in src to the type of dst and stores it
// there.
//
// If dst is a pointer, a nil pointer is first set to a newly allocated
// value and the conversion targets the pointee. Supported destinations
// are the integer, unsigned, float and bool kinds (parsed with strconv
// after trimming surrounding white space; empty src stores the zero
// value), strings (src is stored verbatim) and byte slices (empty src
// stores an empty, non-nil slice). The zero reflect.Value is accepted
// and ignored.
//
// An error is returned when parsing fails or when dst is of any other
// kind, including slices whose element kind is not a byte.
func copyValue(dst reflect.Value, src []byte) error {
	// Keep the original destination so error messages name the type the
	// caller passed rather than the pointee.
	dst0 := dst

	if dst.Kind() == reflect.Pointer {
		if dst.IsNil() {
			dst.Set(reflect.New(dst.Type().Elem()))
		}
		dst = dst.Elem()
	}

	switch dst.Kind() {
	case reflect.Invalid:
		// Nowhere to store the data; nothing to do.
	default:
		return errors.New("cannot unmarshal into " + dst0.Type().String())
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if len(src) == 0 {
			dst.SetInt(0)
			return nil
		}
		n, err := strconv.ParseInt(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
		if err != nil {
			return err
		}
		dst.SetInt(n)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		if len(src) == 0 {
			dst.SetUint(0)
			return nil
		}
		n, err := strconv.ParseUint(strings.TrimSpace(string(src)), 10, dst.Type().Bits())
		if err != nil {
			return err
		}
		dst.SetUint(n)
	case reflect.Float32, reflect.Float64:
		if len(src) == 0 {
			dst.SetFloat(0)
			return nil
		}
		f, err := strconv.ParseFloat(strings.TrimSpace(string(src)), dst.Type().Bits())
		if err != nil {
			return err
		}
		dst.SetFloat(f)
	case reflect.Bool:
		if len(src) == 0 {
			dst.SetBool(false)
			return nil
		}
		b, err := strconv.ParseBool(strings.TrimSpace(string(src)))
		if err != nil {
			return err
		}
		dst.SetBool(b)
	case reflect.String:
		dst.SetString(string(src))
	case reflect.Slice:
		// SetBytes panics for anything but a slice of bytes, so report
		// other slice types as an ordinary error instead.
		if dst.Type().Elem().Kind() != reflect.Uint8 {
			return errors.New("cannot unmarshal into " + dst0.Type().String())
		}
		if len(src) == 0 {
			// Store a non-nil slice to signal that the data was present.
			src = []byte{}
		}
		dst.SetBytes(src)
	}
	return nil
}
688
689
690
691
692
693
694 func (d *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []string, start *StartElement, depth int) (consumed bool, err error) {
695 recurse := false
696 Loop:
697 for i := range tinfo.fields {
698 finfo := &tinfo.fields[i]
699 if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space {
700 continue
701 }
702 for j := range parents {
703 if parents[j] != finfo.parents[j] {
704 continue Loop
705 }
706 }
707 if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
708
709 return true, d.unmarshal(finfo.value(sv, initNilPointers), start, depth+1)
710 }
711 if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
712
713
714
715 recurse = true
716
717
718
719 parents = finfo.parents[:len(parents)+1]
720 break
721 }
722 }
723 if !recurse {
724
725 return false, nil
726 }
727
728
729
730 for {
731 var tok Token
732 tok, err = d.Token()
733 if err != nil {
734 return true, err
735 }
736 switch t := tok.(type) {
737 case StartElement:
738
739
740 consumed2, err := d.unmarshalPath(tinfo, sv, parents, &t, depth)
741 if err != nil {
742 return true, err
743 }
744 if !consumed2 {
745 if err := d.Skip(); err != nil {
746 return true, err
747 }
748 }
749 case EndElement:
750 return true, nil
751 }
752 }
753 }
754
755
756
757
758
759
760 func (d *Decoder) Skip() error {
761 var depth int64
762 for {
763 tok, err := d.Token()
764 if err != nil {
765 return err
766 }
767 switch tok.(type) {
768 case StartElement:
769 depth++
770 case EndElement:
771 if depth == 0 {
772 return nil
773 }
774 depth--
775 }
776 }
777 }
778
View as plain text