discriminator.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568
  1. // Copyright 2022 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package json
  5. import (
  6. "fmt"
  7. "reflect"
  8. "regexp"
  9. "strconv"
  10. "sync"
  11. )
  12. // DiscriminatorToTypeFunc is used to get a reflect.Type from its
  13. // discriminator.
  14. type DiscriminatorToTypeFunc func(discriminator string) (reflect.Type, bool)
  15. // TypeToDiscriminatorFunc is used to get a discriminator string from a
  16. // reflect.Type. Empty return value suppresses discriminator rendering.
  17. type TypeToDiscriminatorFunc func(reflect.Type) (discriminator string)
  18. // DefaultDiscriminatorFunc is shorthand for the ShortName func and is used when
  19. // no other discriminator func is set explicitly
  20. var DefaultDiscriminatorFunc = ShortName
  21. // ShortName returns the type name in golang without the package name
  22. func ShortName(t reflect.Type) (discriminator string) {
  23. tn := t.Name()
  24. if tn == "" {
  25. return t.String()
  26. }
  27. return tn
  28. }
  29. // FullName return the name of the type prefixed with the package name as
  30. // appropriate
  31. func FullName(t reflect.Type) (discriminator string) {
  32. tn := t.Name()
  33. if tn == "" {
  34. return t.String()
  35. }
  36. if pp := t.PkgPath(); pp != "" {
  37. return fmt.Sprintf("%s.%s", pp, tn)
  38. }
  39. return tn
  40. }
  41. // DiscriminatorEncodeMode is a mask that describes the different encode
  42. // options.
  43. type DiscriminatorEncodeMode uint8
  44. const (
  45. // DiscriminatorEncodeTypeNameRootValue causes the type name to be encoded
  46. // for the root value.
  47. DiscriminatorEncodeTypeNameRootValue DiscriminatorEncodeMode = 1 << iota
  48. // DiscriminatorEncodeTypeNameAllObjects causes the type name to be encoded
  49. // for all struct and map values. Please note this specifically does not
  50. // apply to the root value.
  51. DiscriminatorEncodeTypeNameAllObjects
  52. // DiscriminatorEncodeTypeNameIfRequired is the default behavior when
  53. // the discriminator is set, and the type name is only encoded if required.
  54. DiscriminatorEncodeTypeNameIfRequired DiscriminatorEncodeMode = 0
  55. )
  56. func (m DiscriminatorEncodeMode) root() bool {
  57. return m&DiscriminatorEncodeTypeNameRootValue > 0
  58. }
  59. func (m DiscriminatorEncodeMode) all() bool {
  60. return m&DiscriminatorEncodeTypeNameAllObjects > 0
  61. }
  62. func (d *decodeState) isDiscriminatorSet() bool {
  63. return d.discriminatorTypeFieldName != "" &&
  64. d.discriminatorValueFieldName != ""
  65. }
  66. // discriminatorOpType describes the current operation related to
  67. // discriminators when reading a JSON object's fields.
  68. type discriminatorOpType uint8
  69. const (
  70. // discriminatorOpTypeNameField indicates the discriminator type name
  71. // field was discovered.
  72. discriminatorOpTypeNameField = iota + 1
  73. // discriminatorOpValueField indicates the discriminator value field
  74. // was discovered.
  75. discriminatorOpValueField
  76. )
  77. func (d *decodeState) discriminatorGetValue() (reflect.Value, error) {
  78. // Record the current offset so we know where the data starts.
  79. offset := d.readIndex()
  80. // Create a temporary decodeState used to inspect the current object
  81. // and determine its discriminator type and decode its value.
  82. dd := &decodeState{
  83. disallowUnknownFields: d.disallowUnknownFields,
  84. useNumber: d.useNumber,
  85. discriminatorToTypeFn: d.discriminatorToTypeFn,
  86. discriminatorTypeFieldName: d.discriminatorTypeFieldName,
  87. discriminatorValueFieldName: d.discriminatorValueFieldName,
  88. }
  89. dd.init(append([]byte{}, d.data[offset:]...))
  90. defer freeScanner(&dd.scan)
  91. dd.scan.reset()
  92. var (
  93. t reflect.Type // the instance of the type
  94. valueOff = -1 // the offset of a possible discriminator value
  95. )
  96. dd.scanWhile(scanSkipSpace)
  97. if dd.opcode != scanBeginObject {
  98. panic(phasePanicMsg)
  99. }
  100. for {
  101. dd.scanWhile(scanSkipSpace)
  102. if dd.opcode == scanEndObject {
  103. // closing } - can only happen on first iteration.
  104. break
  105. }
  106. if dd.opcode != scanBeginLiteral {
  107. panic(phasePanicMsg)
  108. }
  109. // Read key.
  110. start := dd.readIndex()
  111. dd.rescanLiteral()
  112. item := dd.data[start:dd.readIndex()]
  113. key, ok := unquote(item)
  114. if !ok {
  115. panic(phasePanicMsg)
  116. }
  117. // Check to see if the key is related to the discriminator.
  118. var discriminatorOp discriminatorOpType
  119. switch key {
  120. case d.discriminatorTypeFieldName:
  121. discriminatorOp = discriminatorOpTypeNameField
  122. case d.discriminatorValueFieldName:
  123. discriminatorOp = discriminatorOpValueField
  124. }
  125. // Read : before value.
  126. if dd.opcode == scanSkipSpace {
  127. dd.scanWhile(scanSkipSpace)
  128. }
  129. if dd.opcode != scanObjectKey {
  130. panic(phasePanicMsg)
  131. }
  132. dd.scanWhile(scanSkipSpace)
  133. // Read value.
  134. valOff := dd.readIndex()
  135. val := dd.valueInterface()
  136. switch discriminatorOp {
  137. case discriminatorOpTypeNameField:
  138. tn, ok := val.(string)
  139. if !ok {
  140. return reflect.Value{}, fmt.Errorf(
  141. "json: discriminator type at offset %d is not string",
  142. offset+valOff)
  143. }
  144. if tn == "" {
  145. return reflect.Value{}, fmt.Errorf(
  146. "json: discriminator type at offset %d is empty",
  147. offset+valOff)
  148. }
  149. // Parse the type name into a type instance.
  150. ti, err := discriminatorParseTypeName(tn, d.discriminatorToTypeFn)
  151. if err != nil {
  152. return reflect.Value{}, err
  153. }
  154. // Assign the type instance to the outer variable, t.
  155. t = ti
  156. // Primitive types and types with Unmarshaler are wrapped in a
  157. // structure with type and value fields. Structures and Maps not
  158. // implementing Unmarshaler use discriminator embedded within their
  159. // content.
  160. if useNestedDiscriminator(t) {
  161. // If the type is a map or a struct not implementing Unmarshaler
  162. // then it is not necessary to continue walking over the current
  163. // JSON object since it will be completely re-scanned to decode
  164. // its value into the discovered type.
  165. dd.opcode = scanEndObject
  166. } else {
  167. // Otherwise if the value offset has been discovered then it is
  168. // safe to stop walking over the current JSON object as well.
  169. if valueOff > -1 {
  170. dd.opcode = scanEndObject
  171. }
  172. }
  173. case discriminatorOpValueField:
  174. valueOff = valOff
  175. // If the type has been discovered then it is safe to stop walking
  176. // over the current JSON object.
  177. if t != nil {
  178. dd.opcode = scanEndObject
  179. }
  180. }
  181. // Next token must be , or }.
  182. if dd.opcode == scanSkipSpace {
  183. dd.scanWhile(scanSkipSpace)
  184. }
  185. if dd.opcode == scanEndObject {
  186. break
  187. }
  188. if dd.opcode != scanObjectValue {
  189. panic(phasePanicMsg)
  190. }
  191. }
  192. // If there is not a type discriminator then return early.
  193. if t == nil {
  194. return reflect.Value{}, fmt.Errorf("json: missing discriminator")
  195. }
  196. // Instantiate a new instance of the discriminated type.
  197. var v reflect.Value
  198. switch t.Kind() {
  199. case reflect.Slice:
  200. // MakeSlice returns a value that is not addressable.
  201. // Instead, use MakeSlice to get the type, then use
  202. // reflect.New to create an addressable value.
  203. v = reflect.New(reflect.MakeSlice(t, 0, 0).Type()).Elem()
  204. case reflect.Map:
  205. // MakeMap returns a value that is not addressable.
  206. // Instead, use MakeMap to get the type, then use
  207. // reflect.New to create an addressable value.
  208. v = reflect.New(reflect.MakeMap(t).Type()).Elem()
  209. case reflect.Complex64, reflect.Complex128:
  210. return reflect.Value{}, fmt.Errorf("json: unsupported discriminator type: %s", t.Kind())
  211. default:
  212. v = reflect.New(t)
  213. }
  214. // Reset the decode state to prepare for decoding the data.
  215. dd.scan.reset()
  216. if useNestedDiscriminator(t) {
  217. // Set the offset to zero since the entire object will be decoded
  218. // into v.
  219. dd.off = 0
  220. } else {
  221. // Set the offset to what it was before the discriminator value was
  222. // read so only the value field is decoded into v.
  223. dd.off = valueOff
  224. }
  225. // This will initialize the correct scan step and op code.
  226. dd.scanWhile(scanSkipSpace)
  227. // Decode the data into the value.
  228. if err := dd.value(v); err != nil {
  229. return reflect.Value{}, err
  230. }
  231. // Check the saved error as well since the decoder.value function does not
  232. // always return an error. If the reflected value is still zero, then it is
  233. // likely the decoder was unable to decode the value.
  234. if err := dd.savedError; err != nil {
  235. switch v.Kind() {
  236. case reflect.Ptr, reflect.Interface:
  237. v = v.Elem()
  238. }
  239. if v.IsZero() {
  240. return reflect.Value{}, err
  241. }
  242. }
  243. return v, nil
  244. }
  245. func (d *decodeState) discriminatorInterfaceDecode(t reflect.Type, v reflect.Value) error {
  246. defer func() {
  247. // Advance the decode state, throwing away the value.
  248. _ = d.objectInterface()
  249. }()
  250. dv, err := d.discriminatorGetValue()
  251. if err != nil {
  252. return err
  253. }
  254. switch dv.Kind() {
  255. case reflect.Map, reflect.Slice:
  256. if dv.Type().AssignableTo(t) {
  257. v.Set(dv)
  258. return nil
  259. }
  260. if pdv := dv.Addr(); pdv.Type().AssignableTo(t) {
  261. v.Set(pdv)
  262. return nil
  263. }
  264. case reflect.Ptr:
  265. if dve := dv.Elem(); dve.Type().AssignableTo(t) {
  266. v.Set(dve)
  267. return nil
  268. }
  269. if dv.Type().AssignableTo(t) {
  270. v.Set(dv)
  271. return nil
  272. }
  273. }
  274. return fmt.Errorf("json: unsupported discriminator kind: %s", dv.Kind())
  275. }
  276. func (o encOpts) isDiscriminatorSet() bool {
  277. return o.discriminatorTypeFieldName != "" &&
  278. o.discriminatorValueFieldName != ""
  279. }
  280. func discriminatorInterfaceEncode(e *encodeState, v reflect.Value, opts encOpts) {
  281. v = v.Elem()
  282. if v.Type().Implements(marshalerType) {
  283. discriminatorValue := opts.discriminatorValueFn(v.Type())
  284. if discriminatorValue == "" {
  285. marshalerEncoder(e, v, opts)
  286. }
  287. e.WriteString(`{"`)
  288. e.WriteString(opts.discriminatorTypeFieldName)
  289. e.WriteString(`":"`)
  290. e.WriteString(discriminatorValue)
  291. e.WriteString(`","`)
  292. e.WriteString(opts.discriminatorValueFieldName)
  293. e.WriteString(`":`)
  294. marshalerEncoder(e, v, opts)
  295. e.WriteByte('}')
  296. return
  297. }
  298. switch v.Kind() {
  299. case reflect.Chan, reflect.Func, reflect.Invalid:
  300. e.error(&UnsupportedValueError{v, fmt.Sprintf("invalid kind: %s", v.Kind())})
  301. case reflect.Map:
  302. e.discriminatorEncodeTypeName = true
  303. newMapEncoder(v.Type())(e, v, opts)
  304. case reflect.Struct:
  305. e.discriminatorEncodeTypeName = true
  306. newStructEncoder(v.Type())(e, v, opts)
  307. case reflect.Ptr:
  308. discriminatorInterfaceEncode(e, v, opts)
  309. default:
  310. discriminatorValue := opts.discriminatorValueFn(v.Type())
  311. if discriminatorValue == "" {
  312. e.reflectValue(v, opts)
  313. return
  314. }
  315. e.WriteString(`{"`)
  316. e.WriteString(opts.discriminatorTypeFieldName)
  317. e.WriteString(`":"`)
  318. e.WriteString(discriminatorValue)
  319. e.WriteString(`","`)
  320. e.WriteString(opts.discriminatorValueFieldName)
  321. e.WriteString(`":`)
  322. e.reflectValue(v, opts)
  323. e.WriteByte('}')
  324. }
  325. }
  326. func discriminatorMapEncode(e *encodeState, v reflect.Value, opts encOpts) {
  327. if !e.discriminatorEncodeTypeName && !opts.discriminatorEncodeMode.all() {
  328. return
  329. }
  330. discriminatorValue := opts.discriminatorValueFn(v.Type())
  331. if discriminatorValue == "" {
  332. return
  333. }
  334. e.WriteByte('"')
  335. e.WriteString(opts.discriminatorTypeFieldName)
  336. e.WriteString(`":"`)
  337. e.WriteString(discriminatorValue)
  338. e.WriteByte('"')
  339. if v.Len() > 0 {
  340. e.WriteByte(',')
  341. }
  342. e.discriminatorEncodeTypeName = false
  343. }
  344. func discriminatorStructEncode(e *encodeState, v reflect.Value, opts encOpts) byte {
  345. if !e.discriminatorEncodeTypeName && !opts.discriminatorEncodeMode.all() {
  346. return '{'
  347. }
  348. discriminatorValue := opts.discriminatorValueFn(v.Type())
  349. if discriminatorValue == "" {
  350. return '{'
  351. }
  352. e.WriteString(`{"`)
  353. e.WriteString(opts.discriminatorTypeFieldName)
  354. e.WriteString(`":"`)
  355. e.WriteString(discriminatorValue)
  356. e.WriteByte('"')
  357. e.discriminatorEncodeTypeName = false
  358. return ','
  359. }
  360. var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
  361. // Discriminator is nested in map and struct unless they implement Unmarshaler.
  362. func useNestedDiscriminator(t reflect.Type) bool {
  363. if t.Implements(unmarshalerType) || reflect.PtrTo(t).Implements(unmarshalerType) {
  364. return false
  365. }
  366. kind := t.Kind()
  367. if kind == reflect.Struct || kind == reflect.Map {
  368. return true
  369. }
  370. return false
  371. }
  372. var discriminatorTypeRegistry = map[string]reflect.Type{
  373. "uint": reflect.TypeOf(uint(0)),
  374. "uint8": reflect.TypeOf(uint8(0)),
  375. "uint16": reflect.TypeOf(uint16(0)),
  376. "uint32": reflect.TypeOf(uint32(0)),
  377. "uint64": reflect.TypeOf(uint64(0)),
  378. "uintptr": reflect.TypeOf(uintptr(0)),
  379. "int": reflect.TypeOf(int(0)),
  380. "int8": reflect.TypeOf(int8(0)),
  381. "int16": reflect.TypeOf(int16(0)),
  382. "int32": reflect.TypeOf(int32(0)),
  383. "int64": reflect.TypeOf(int64(0)),
  384. "float32": reflect.TypeOf(float32(0)),
  385. "float64": reflect.TypeOf(float64(0)),
  386. "bool": reflect.TypeOf(true),
  387. "string": reflect.TypeOf(""),
  388. "any": reflect.TypeOf((*interface{})(nil)).Elem(),
  389. "interface{}": reflect.TypeOf((*interface{})(nil)).Elem(),
  390. "interface {}": reflect.TypeOf((*interface{})(nil)).Elem(),
  391. // Not supported, but here to prevent the decoder from panicing
  392. // if encountered.
  393. "complex64": reflect.TypeOf(complex64(0)),
  394. "complex128": reflect.TypeOf(complex128(0)),
  395. }
  396. // discriminatorPointerTypeCache caches the pointer type for another type.
  397. // For example, a key that was the int type would have a value that is the
  398. // *int type.
  399. var discriminatorPointerTypeCache sync.Map // map[reflect.Type]reflect.Type
  400. // cachedPointerType returns the pointer type for another and avoids repeated
  401. // work by using a cache.
  402. func cachedPointerType(t reflect.Type) reflect.Type {
  403. if value, ok := discriminatorPointerTypeCache.Load(t); ok {
  404. return value.(reflect.Type)
  405. }
  406. pt := reflect.New(t).Type()
  407. value, _ := discriminatorPointerTypeCache.LoadOrStore(t, pt)
  408. return value.(reflect.Type)
  409. }
  410. var (
  411. mapPatt = regexp.MustCompile(`^\*?map\[([^\]]+)\](.+)$`)
  412. arrayPatt = regexp.MustCompile(`^\*?\[(\d+)\](.+)$`)
  413. slicePatt = regexp.MustCompile(`^\*?\[\](.+)$`)
  414. )
  415. // discriminatorParseTypeName returns a reflect.Type for the given type name.
  416. func discriminatorParseTypeName(
  417. typeName string,
  418. typeFn DiscriminatorToTypeFunc) (reflect.Type, error) {
  419. // Check to see if the type is an array, map, or slice.
  420. var (
  421. aln = -1 // array length
  422. etn string // map or slice element type name
  423. ktn string // map key type name
  424. )
  425. if m := arrayPatt.FindStringSubmatch(typeName); len(m) > 0 {
  426. i, err := strconv.Atoi(m[1])
  427. if err != nil {
  428. return nil, err
  429. }
  430. aln = i
  431. etn = m[2]
  432. } else if m := slicePatt.FindStringSubmatch(typeName); len(m) > 0 {
  433. etn = m[1]
  434. } else if m := mapPatt.FindStringSubmatch(typeName); len(m) > 0 {
  435. ktn = m[1]
  436. etn = m[2]
  437. }
  438. // indirectTypeName checks to see if the type name begins with a
  439. // "*" characters. If it does, then the type name sans the "*"
  440. // character is returned along with a true value indicating the
  441. // type is a pointer. Otherwise the original type name is returned
  442. // along with a false value.
  443. indirectTypeName := func(tn string) (string, bool) {
  444. if len(tn) > 1 && tn[0] == '*' {
  445. return tn[1:], true
  446. }
  447. return tn, false
  448. }
  449. lookupType := func(tn string) (reflect.Type, bool) {
  450. // Get the actual type name and a flag indicating whether the
  451. // type is a pointer.
  452. n, p := indirectTypeName(tn)
  453. var t reflect.Type
  454. ok := false
  455. // look up the type in the external registry to allow name override.
  456. if typeFn != nil {
  457. t, ok = typeFn(n)
  458. }
  459. if !ok {
  460. // Use the built-in registry if the external registry fails
  461. if t, ok = discriminatorTypeRegistry[n]; !ok {
  462. return nil, false
  463. }
  464. }
  465. // If the type was a pointer then get the type's pointer type.
  466. if p {
  467. t = cachedPointerType(t)
  468. }
  469. return t, true
  470. }
  471. var t reflect.Type
  472. if ktn == "" && etn != "" {
  473. et, ok := lookupType(etn)
  474. if !ok {
  475. return nil, fmt.Errorf("json: invalid array/slice element type: %s", etn)
  476. }
  477. if aln > -1 {
  478. // Array
  479. t = reflect.ArrayOf(aln, et)
  480. } else {
  481. // Slice
  482. t = reflect.SliceOf(et)
  483. }
  484. } else if ktn != "" && etn != "" {
  485. // Map
  486. kt, ok := lookupType(ktn)
  487. if !ok {
  488. return nil, fmt.Errorf("json: invalid map key type: %s", ktn)
  489. }
  490. et, ok := lookupType(etn)
  491. if !ok {
  492. return nil, fmt.Errorf("json: invalid map element type: %s", etn)
  493. }
  494. t = reflect.MapOf(kt, et)
  495. } else {
  496. var ok bool
  497. if t, ok = lookupType(typeName); !ok {
  498. return nil, fmt.Errorf("json: invalid discriminator type: %s", typeName)
  499. }
  500. }
  501. return t, nil
  502. }