mapiter.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. package mapiter
  2. import (
  3. "context"
  4. "reflect"
  5. "sync"
  6. "github.com/pkg/errors"
  7. )
  8. // Iterate creates an iterator from arbitrary map types. This is not
  9. // the most efficient tool, but it's the quickest way to create an
  10. // iterator for maps.
  11. // Also, note that you cannot make any assumptions on the order of
  12. // pairs being returned.
  13. func Iterate(ctx context.Context, m interface{}) (Iterator, error) {
  14. mrv := reflect.ValueOf(m)
  15. if mrv.Kind() != reflect.Map {
  16. return nil, errors.Errorf(`argument must be a map (%s)`, mrv.Type())
  17. }
  18. ch := make(chan *Pair)
  19. go func(ctx context.Context, ch chan *Pair, mrv reflect.Value) {
  20. defer close(ch)
  21. for _, key := range mrv.MapKeys() {
  22. value := mrv.MapIndex(key)
  23. pair := &Pair{
  24. Key: key.Interface(),
  25. Value: value.Interface(),
  26. }
  27. select {
  28. case <-ctx.Done():
  29. return
  30. case ch <- pair:
  31. }
  32. }
  33. }(ctx, ch, mrv)
  34. return New(ch), nil
  35. }
  36. // Source represents a map that knows how to create an iterator
  37. type Source interface {
  38. Iterate(context.Context) Iterator
  39. }
  // Pair is one entry of a map: a key together with the value stored under
// it. Both halves are held as interface{} so any map type can be described.
type Pair struct {
	Key, Value interface{}
}
  45. // Iterator iterates through keys and values of a map
  46. type Iterator interface {
  47. Next(context.Context) bool
  48. Pair() *Pair
  49. }
  50. type iter struct {
  51. ch chan *Pair
  52. mu sync.RWMutex
  53. next *Pair
  54. }
  // Visitor receives the key and value of each pair handed to it by Walk.
// Returning a non-nil error stops the walk.
type Visitor interface {
	Visit(key, value interface{}) error
}

// VisitorFunc lets an ordinary function be used wherever a Visitor is
// expected.
type VisitorFunc func(interface{}, interface{}) error

// Visit forwards key and value to fn and returns whatever fn returns.
func (fn VisitorFunc) Visit(key interface{}, value interface{}) error {
	return fn(key, value)
}
  64. func New(ch chan *Pair) Iterator {
  65. return &iter{
  66. ch: ch,
  67. }
  68. }
  69. // Next returns true if there are more items to read from the iterator
  70. func (i *iter) Next(ctx context.Context) bool {
  71. i.mu.RLock()
  72. if i.ch == nil {
  73. i.mu.RUnlock()
  74. return false
  75. }
  76. i.mu.RUnlock()
  77. i.mu.Lock()
  78. defer i.mu.Unlock()
  79. select {
  80. case <-ctx.Done():
  81. i.ch = nil
  82. return false
  83. case v, ok := <-i.ch:
  84. if !ok {
  85. i.ch = nil
  86. return false
  87. }
  88. i.next = v
  89. return true
  90. }
  91. //nolint:govet
  92. return false // never reached
  93. }
  94. // Pair returns the currently buffered Pair. Calling Next() will reset its value
  95. func (i *iter) Pair() *Pair {
  96. i.mu.RLock()
  97. defer i.mu.RUnlock()
  98. return i.next
  99. }
  100. // Walk walks through each element in the map
  101. func Walk(ctx context.Context, s Source, v Visitor) error {
  102. for i := s.Iterate(ctx); i.Next(ctx); {
  103. pair := i.Pair()
  104. if err := v.Visit(pair.Key, pair.Value); err != nil {
  105. return errors.Wrapf(err, `failed to visit key %s`, pair.Key)
  106. }
  107. }
  108. return nil
  109. }
  110. // AsMap returns the values obtained from the source as a map
  111. func AsMap(ctx context.Context, s interface{}, v interface{}) error {
  112. var iter Iterator
  113. switch reflect.ValueOf(s).Kind() {
  114. case reflect.Map:
  115. x, err := Iterate(ctx, s)
  116. if err != nil {
  117. return errors.Wrap(err, `failed to iterate over map type`)
  118. }
  119. iter = x
  120. default:
  121. ssrc, ok := s.(Source)
  122. if !ok {
  123. return errors.Errorf(`cannot iterate over %T: not a mapiter.Source type`, s)
  124. }
  125. iter = ssrc.Iterate(ctx)
  126. }
  127. dst := reflect.ValueOf(v)
  128. // dst MUST be a pointer to a map type
  129. if kind := dst.Kind(); kind != reflect.Ptr {
  130. return errors.Errorf(`dst must be a pointer to a map (%s)`, dst.Type())
  131. }
  132. dst = dst.Elem()
  133. if dst.Kind() != reflect.Map {
  134. return errors.Errorf(`dst must be a pointer to a map (%s)`, dst.Type())
  135. }
  136. if dst.IsNil() {
  137. dst.Set(reflect.MakeMap(dst.Type()))
  138. }
  139. // dst must be assignable
  140. if !dst.CanSet() {
  141. return errors.New(`dst is not writeable`)
  142. }
  143. keytyp := dst.Type().Key()
  144. valtyp := dst.Type().Elem()
  145. for iter.Next(ctx) {
  146. pair := iter.Pair()
  147. rvkey := reflect.ValueOf(pair.Key)
  148. rvvalue := reflect.ValueOf(pair.Value)
  149. if !rvkey.Type().AssignableTo(keytyp) {
  150. return errors.Errorf(`cannot assign key of type %s to map key of type %s`, rvkey.Type(), keytyp)
  151. }
  152. switch rvvalue.Kind() {
  153. // we can only check if we can assign to rvvalue to valtyp if it's non-nil
  154. case reflect.Invalid:
  155. rvvalue = reflect.New(valtyp).Elem()
  156. default:
  157. if !rvvalue.Type().AssignableTo(valtyp) {
  158. return errors.Errorf(`cannot assign value of type %s to map value of type %s`, rvvalue.Type(), valtyp)
  159. }
  160. }
  161. dst.SetMapIndex(rvkey, rvvalue)
  162. }
  163. return nil
  164. }