| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568 |
- // Copyright 2022 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package json
- import (
- "fmt"
- "reflect"
- "regexp"
- "strconv"
- "sync"
- )
// DiscriminatorToTypeFunc resolves a discriminator string read from JSON to
// the reflect.Type it names. The boolean result reports whether the name was
// recognized; when it is false the decoder falls back to its built-in table
// of predeclared type names (int, string, bool, and so on).
//
// The function is called with plain names only: a single leading "*" is
// stripped before the call (the decoder then uses the pointer type), and for
// array, slice, and map names such as "[]Foo" or "map[string]Foo" it is
// called with the key and element names rather than the whole name.
type DiscriminatorToTypeFunc func(discriminator string) (reflect.Type, bool)
// TypeToDiscriminatorFunc returns the discriminator string written to JSON to
// identify values of the given reflect.Type. Returning an empty string
// suppresses rendering of the discriminator for that type.
type TypeToDiscriminatorFunc func(reflect.Type) (discriminator string)
// DefaultDiscriminatorFunc is the type-to-discriminator function used when no
// other discriminator function is set explicitly. It defaults to ShortName.
var DefaultDiscriminatorFunc = ShortName
// ShortName returns the name of t without its package qualifier, for example
// "Value" for reflect.Value. Unnamed types such as []int or *reflect.Value
// have no name of their own, so their reflect string form is returned
// instead (which may include package qualifiers of component types).
func ShortName(t reflect.Type) (discriminator string) {
	if discriminator = t.Name(); discriminator == "" {
		discriminator = t.String()
	}
	return discriminator
}
// FullName returns the name of t qualified by its package import path, for
// example "reflect.Value". Named types without a package path (predeclared
// types such as int) are returned unqualified, and unnamed types such as
// []int fall back to their reflect string form.
func FullName(t reflect.Type) (discriminator string) {
	discriminator = t.Name()
	switch {
	case discriminator == "":
		discriminator = t.String()
	case t.PkgPath() != "":
		discriminator = fmt.Sprintf("%s.%s", t.PkgPath(), discriminator)
	}
	return discriminator
}
// DiscriminatorEncodeMode is a bit mask selecting when the discriminator type
// name is written during encoding. The non-zero modes are independent bits
// and may be combined with bitwise OR.
type DiscriminatorEncodeMode uint8

const (
	// DiscriminatorEncodeTypeNameIfRequired is the default behavior when
	// the discriminator is set, and the type name is only encoded if required.
	DiscriminatorEncodeTypeNameIfRequired DiscriminatorEncodeMode = 0

	// DiscriminatorEncodeTypeNameRootValue causes the type name to be encoded
	// for the root value.
	DiscriminatorEncodeTypeNameRootValue DiscriminatorEncodeMode = 1 << 0

	// DiscriminatorEncodeTypeNameAllObjects causes the type name to be encoded
	// for all struct and map values. Please note this specifically does not
	// apply to the root value.
	DiscriminatorEncodeTypeNameAllObjects DiscriminatorEncodeMode = 1 << 1
)

// root reports whether the DiscriminatorEncodeTypeNameRootValue bit is set.
func (m DiscriminatorEncodeMode) root() bool {
	return m&DiscriminatorEncodeTypeNameRootValue != 0
}

// all reports whether the DiscriminatorEncodeTypeNameAllObjects bit is set.
func (m DiscriminatorEncodeMode) all() bool {
	return m&DiscriminatorEncodeTypeNameAllObjects != 0
}
- func (d *decodeState) isDiscriminatorSet() bool {
- return d.discriminatorTypeFieldName != "" &&
- d.discriminatorValueFieldName != ""
- }
// discriminatorOpType describes the current operation related to
// discriminators when reading a JSON object's fields. The zero value means
// the field is unrelated to the discriminator.
type discriminatorOpType uint8

// The constants are explicitly typed so that they are discriminatorOpType
// values rather than untyped ints.
const (
	// discriminatorOpTypeNameField indicates the discriminator type name
	// field was discovered.
	discriminatorOpTypeNameField discriminatorOpType = iota + 1
	// discriminatorOpValueField indicates the discriminator value field
	// was discovered.
	discriminatorOpValueField
)
- func (d *decodeState) discriminatorGetValue() (reflect.Value, error) {
- // Record the current offset so we know where the data starts.
- offset := d.readIndex()
- // Create a temporary decodeState used to inspect the current object
- // and determine its discriminator type and decode its value.
- dd := &decodeState{
- disallowUnknownFields: d.disallowUnknownFields,
- useNumber: d.useNumber,
- discriminatorToTypeFn: d.discriminatorToTypeFn,
- discriminatorTypeFieldName: d.discriminatorTypeFieldName,
- discriminatorValueFieldName: d.discriminatorValueFieldName,
- }
- dd.init(append([]byte{}, d.data[offset:]...))
- defer freeScanner(&dd.scan)
- dd.scan.reset()
- var (
- t reflect.Type // the instance of the type
- valueOff = -1 // the offset of a possible discriminator value
- )
- dd.scanWhile(scanSkipSpace)
- if dd.opcode != scanBeginObject {
- panic(phasePanicMsg)
- }
- for {
- dd.scanWhile(scanSkipSpace)
- if dd.opcode == scanEndObject {
- // closing } - can only happen on first iteration.
- break
- }
- if dd.opcode != scanBeginLiteral {
- panic(phasePanicMsg)
- }
- // Read key.
- start := dd.readIndex()
- dd.rescanLiteral()
- item := dd.data[start:dd.readIndex()]
- key, ok := unquote(item)
- if !ok {
- panic(phasePanicMsg)
- }
- // Check to see if the key is related to the discriminator.
- var discriminatorOp discriminatorOpType
- switch key {
- case d.discriminatorTypeFieldName:
- discriminatorOp = discriminatorOpTypeNameField
- case d.discriminatorValueFieldName:
- discriminatorOp = discriminatorOpValueField
- }
- // Read : before value.
- if dd.opcode == scanSkipSpace {
- dd.scanWhile(scanSkipSpace)
- }
- if dd.opcode != scanObjectKey {
- panic(phasePanicMsg)
- }
- dd.scanWhile(scanSkipSpace)
- // Read value.
- valOff := dd.readIndex()
- val := dd.valueInterface()
- switch discriminatorOp {
- case discriminatorOpTypeNameField:
- tn, ok := val.(string)
- if !ok {
- return reflect.Value{}, fmt.Errorf(
- "json: discriminator type at offset %d is not string",
- offset+valOff)
- }
- if tn == "" {
- return reflect.Value{}, fmt.Errorf(
- "json: discriminator type at offset %d is empty",
- offset+valOff)
- }
- // Parse the type name into a type instance.
- ti, err := discriminatorParseTypeName(tn, d.discriminatorToTypeFn)
- if err != nil {
- return reflect.Value{}, err
- }
- // Assign the type instance to the outer variable, t.
- t = ti
- // Primitive types and types with Unmarshaler are wrapped in a
- // structure with type and value fields. Structures and Maps not
- // implementing Unmarshaler use discriminator embedded within their
- // content.
- if useNestedDiscriminator(t) {
- // If the type is a map or a struct not implementing Unmarshaler
- // then it is not necessary to continue walking over the current
- // JSON object since it will be completely re-scanned to decode
- // its value into the discovered type.
- dd.opcode = scanEndObject
- } else {
- // Otherwise if the value offset has been discovered then it is
- // safe to stop walking over the current JSON object as well.
- if valueOff > -1 {
- dd.opcode = scanEndObject
- }
- }
- case discriminatorOpValueField:
- valueOff = valOff
- // If the type has been discovered then it is safe to stop walking
- // over the current JSON object.
- if t != nil {
- dd.opcode = scanEndObject
- }
- }
- // Next token must be , or }.
- if dd.opcode == scanSkipSpace {
- dd.scanWhile(scanSkipSpace)
- }
- if dd.opcode == scanEndObject {
- break
- }
- if dd.opcode != scanObjectValue {
- panic(phasePanicMsg)
- }
- }
- // If there is not a type discriminator then return early.
- if t == nil {
- return reflect.Value{}, fmt.Errorf("json: missing discriminator")
- }
- // Instantiate a new instance of the discriminated type.
- var v reflect.Value
- switch t.Kind() {
- case reflect.Slice:
- // MakeSlice returns a value that is not addressable.
- // Instead, use MakeSlice to get the type, then use
- // reflect.New to create an addressable value.
- v = reflect.New(reflect.MakeSlice(t, 0, 0).Type()).Elem()
- case reflect.Map:
- // MakeMap returns a value that is not addressable.
- // Instead, use MakeMap to get the type, then use
- // reflect.New to create an addressable value.
- v = reflect.New(reflect.MakeMap(t).Type()).Elem()
- case reflect.Complex64, reflect.Complex128:
- return reflect.Value{}, fmt.Errorf("json: unsupported discriminator type: %s", t.Kind())
- default:
- v = reflect.New(t)
- }
- // Reset the decode state to prepare for decoding the data.
- dd.scan.reset()
- if useNestedDiscriminator(t) {
- // Set the offset to zero since the entire object will be decoded
- // into v.
- dd.off = 0
- } else {
- // Set the offset to what it was before the discriminator value was
- // read so only the value field is decoded into v.
- dd.off = valueOff
- }
- // This will initialize the correct scan step and op code.
- dd.scanWhile(scanSkipSpace)
- // Decode the data into the value.
- if err := dd.value(v); err != nil {
- return reflect.Value{}, err
- }
- // Check the saved error as well since the decoder.value function does not
- // always return an error. If the reflected value is still zero, then it is
- // likely the decoder was unable to decode the value.
- if err := dd.savedError; err != nil {
- switch v.Kind() {
- case reflect.Ptr, reflect.Interface:
- v = v.Elem()
- }
- if v.IsZero() {
- return reflect.Value{}, err
- }
- }
- return v, nil
- }
- func (d *decodeState) discriminatorInterfaceDecode(t reflect.Type, v reflect.Value) error {
- defer func() {
- // Advance the decode state, throwing away the value.
- _ = d.objectInterface()
- }()
- dv, err := d.discriminatorGetValue()
- if err != nil {
- return err
- }
- switch dv.Kind() {
- case reflect.Map, reflect.Slice:
- if dv.Type().AssignableTo(t) {
- v.Set(dv)
- return nil
- }
- if pdv := dv.Addr(); pdv.Type().AssignableTo(t) {
- v.Set(pdv)
- return nil
- }
- case reflect.Ptr:
- if dve := dv.Elem(); dve.Type().AssignableTo(t) {
- v.Set(dve)
- return nil
- }
- if dv.Type().AssignableTo(t) {
- v.Set(dv)
- return nil
- }
- }
- return fmt.Errorf("json: unsupported discriminator kind: %s", dv.Kind())
- }
- func (o encOpts) isDiscriminatorSet() bool {
- return o.discriminatorTypeFieldName != "" &&
- o.discriminatorValueFieldName != ""
- }
- func discriminatorInterfaceEncode(e *encodeState, v reflect.Value, opts encOpts) {
- v = v.Elem()
- if v.Type().Implements(marshalerType) {
- discriminatorValue := opts.discriminatorValueFn(v.Type())
- if discriminatorValue == "" {
- marshalerEncoder(e, v, opts)
- }
- e.WriteString(`{"`)
- e.WriteString(opts.discriminatorTypeFieldName)
- e.WriteString(`":"`)
- e.WriteString(discriminatorValue)
- e.WriteString(`","`)
- e.WriteString(opts.discriminatorValueFieldName)
- e.WriteString(`":`)
- marshalerEncoder(e, v, opts)
- e.WriteByte('}')
- return
- }
- switch v.Kind() {
- case reflect.Chan, reflect.Func, reflect.Invalid:
- e.error(&UnsupportedValueError{v, fmt.Sprintf("invalid kind: %s", v.Kind())})
- case reflect.Map:
- e.discriminatorEncodeTypeName = true
- newMapEncoder(v.Type())(e, v, opts)
- case reflect.Struct:
- e.discriminatorEncodeTypeName = true
- newStructEncoder(v.Type())(e, v, opts)
- case reflect.Ptr:
- discriminatorInterfaceEncode(e, v, opts)
- default:
- discriminatorValue := opts.discriminatorValueFn(v.Type())
- if discriminatorValue == "" {
- e.reflectValue(v, opts)
- return
- }
- e.WriteString(`{"`)
- e.WriteString(opts.discriminatorTypeFieldName)
- e.WriteString(`":"`)
- e.WriteString(discriminatorValue)
- e.WriteString(`","`)
- e.WriteString(opts.discriminatorValueFieldName)
- e.WriteString(`":`)
- e.reflectValue(v, opts)
- e.WriteByte('}')
- }
- }
- func discriminatorMapEncode(e *encodeState, v reflect.Value, opts encOpts) {
- if !e.discriminatorEncodeTypeName && !opts.discriminatorEncodeMode.all() {
- return
- }
- discriminatorValue := opts.discriminatorValueFn(v.Type())
- if discriminatorValue == "" {
- return
- }
- e.WriteByte('"')
- e.WriteString(opts.discriminatorTypeFieldName)
- e.WriteString(`":"`)
- e.WriteString(discriminatorValue)
- e.WriteByte('"')
- if v.Len() > 0 {
- e.WriteByte(',')
- }
- e.discriminatorEncodeTypeName = false
- }
- func discriminatorStructEncode(e *encodeState, v reflect.Value, opts encOpts) byte {
- if !e.discriminatorEncodeTypeName && !opts.discriminatorEncodeMode.all() {
- return '{'
- }
- discriminatorValue := opts.discriminatorValueFn(v.Type())
- if discriminatorValue == "" {
- return '{'
- }
- e.WriteString(`{"`)
- e.WriteString(opts.discriminatorTypeFieldName)
- e.WriteString(`":"`)
- e.WriteString(discriminatorValue)
- e.WriteByte('"')
- e.discriminatorEncodeTypeName = false
- return ','
- }
- var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
- // Discriminator is nested in map and struct unless they implement Unmarshaler.
- func useNestedDiscriminator(t reflect.Type) bool {
- if t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) {
- return false
- }
- kind := t.Kind()
- if kind == reflect.Struct || kind == reflect.Map {
- return true
- }
- return false
- }
// discriminatorTypeRegistry maps the names of Go's predeclared types (plus
// the spellings "any", "interface{}" and "interface {}" of the empty
// interface) to their reflect.Type. It is the fallback used when resolving
// discriminator type names that the caller-supplied DiscriminatorToTypeFunc
// does not recognize.
var discriminatorTypeRegistry = func() map[string]reflect.Type {
	emptyInterface := reflect.TypeOf((*interface{})(nil)).Elem()
	return map[string]reflect.Type{
		// Signed integers.
		"int":   reflect.TypeOf(int(0)),
		"int8":  reflect.TypeOf(int8(0)),
		"int16": reflect.TypeOf(int16(0)),
		"int32": reflect.TypeOf(int32(0)),
		"int64": reflect.TypeOf(int64(0)),

		// Unsigned integers.
		"uint":    reflect.TypeOf(uint(0)),
		"uint8":   reflect.TypeOf(uint8(0)),
		"uint16":  reflect.TypeOf(uint16(0)),
		"uint32":  reflect.TypeOf(uint32(0)),
		"uint64":  reflect.TypeOf(uint64(0)),
		"uintptr": reflect.TypeOf(uintptr(0)),

		// Floats, booleans and strings.
		"float32": reflect.TypeOf(float32(0)),
		"float64": reflect.TypeOf(float64(0)),
		"bool":    reflect.TypeOf(true),
		"string":  reflect.TypeOf(""),

		// The empty interface under each spelling it may appear as.
		"any":          emptyInterface,
		"interface{}":  emptyInterface,
		"interface {}": emptyInterface,

		// Not supported, but here to prevent the decoder from panicking
		// if encountered.
		"complex64":  reflect.TypeOf(complex64(0)),
		"complex128": reflect.TypeOf(complex128(0)),
	}
}()
// discriminatorPointerTypeCache caches the pointer type for another type.
// For example, a key that was the int type would have a value that is the
// *int type.
var discriminatorPointerTypeCache sync.Map // map[reflect.Type]reflect.Type

// cachedPointerType returns the pointer type *T for t, memoizing the result
// in discriminatorPointerTypeCache.
func cachedPointerType(t reflect.Type) reflect.Type {
	if value, ok := discriminatorPointerTypeCache.Load(t); ok {
		return value.(reflect.Type)
	}
	// reflect.PtrTo derives *T directly. reflect.New(t).Type() would also
	// work but allocates a zero T just to read its type.
	value, _ := discriminatorPointerTypeCache.LoadOrStore(t, reflect.PtrTo(t))
	return value.(reflect.Type)
}
// Patterns recognizing composite discriminator type names. Each may be
// preceded by a single "*", which is not captured by any group.
var (
	// slicePatt matches "[]T"; group 1 is the element type name.
	slicePatt = regexp.MustCompile(`^\*?\[\](.+)$`)

	// arrayPatt matches "[N]T"; group 1 is the decimal length N and group 2
	// the element type name.
	arrayPatt = regexp.MustCompile(`^\*?\[(\d+)\](.+)$`)

	// mapPatt matches "map[K]V"; group 1 is the key type name (which cannot
	// contain ']') and group 2 the element type name.
	mapPatt = regexp.MustCompile(`^\*?map\[([^\]]+)\](.+)$`)
)
- // discriminatorParseTypeName returns a reflect.Type for the given type name.
- func discriminatorParseTypeName(
- typeName string,
- typeFn DiscriminatorToTypeFunc) (reflect.Type, error) {
- // Check to see if the type is an array, map, or slice.
- var (
- aln = -1 // array length
- etn string // map or slice element type name
- ktn string // map key type name
- )
- if m := arrayPatt.FindStringSubmatch(typeName); len(m) > 0 {
- i, err := strconv.Atoi(m[1])
- if err != nil {
- return nil, err
- }
- aln = i
- etn = m[2]
- } else if m := slicePatt.FindStringSubmatch(typeName); len(m) > 0 {
- etn = m[1]
- } else if m := mapPatt.FindStringSubmatch(typeName); len(m) > 0 {
- ktn = m[1]
- etn = m[2]
- }
- // indirectTypeName checks to see if the type name begins with a
- // "*" characters. If it does, then the type name sans the "*"
- // character is returned along with a true value indicating the
- // type is a pointer. Otherwise the original type name is returned
- // along with a false value.
- indirectTypeName := func(tn string) (string, bool) {
- if len(tn) > 1 && tn[0] == '*' {
- return tn[1:], true
- }
- return tn, false
- }
- lookupType := func(tn string) (reflect.Type, bool) {
- // Get the actual type name and a flag indicating whether the
- // type is a pointer.
- n, p := indirectTypeName(tn)
- var t reflect.Type
- ok := false
- // look up the type in the external registry to allow name override.
- if typeFn != nil {
- t, ok = typeFn(n)
- }
- if !ok {
- // Use the built-in registry if the external registry fails
- if t, ok = discriminatorTypeRegistry[n]; !ok {
- return nil, false
- }
- }
- // If the type was a pointer then get the type's pointer type.
- if p {
- t = cachedPointerType(t)
- }
- return t, true
- }
- var t reflect.Type
- if ktn == "" && etn != "" {
- et, ok := lookupType(etn)
- if !ok {
- return nil, fmt.Errorf("json: invalid array/slice element type: %s", etn)
- }
- if aln > -1 {
- // Array
- t = reflect.ArrayOf(aln, et)
- } else {
- // Slice
- t = reflect.SliceOf(et)
- }
- } else if ktn != "" && etn != "" {
- // Map
- kt, ok := lookupType(ktn)
- if !ok {
- return nil, fmt.Errorf("json: invalid map key type: %s", ktn)
- }
- et, ok := lookupType(etn)
- if !ok {
- return nil, fmt.Errorf("json: invalid map element type: %s", etn)
- }
- t = reflect.MapOf(kt, et)
- } else {
- var ok bool
- if t, ok = lookupType(typeName); !ok {
- return nil, fmt.Errorf("json: invalid discriminator type: %s", typeName)
- }
- }
- return t, nil
- }
|