iptree.go 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
// Package iptree implements a radix tree data structure for IPv4 and IPv6 networks.
  2. package iptree
  3. import (
  4. "fmt"
  5. "net"
  6. "github.com/infobloxopen/go-trees/numtree"
  7. )
  8. const (
  9. iPv4Bits = net.IPv4len * 8
  10. iPv6Bits = net.IPv6len * 8
  11. )
  12. var (
  13. iPv4MaxMask = net.CIDRMask(iPv4Bits, iPv4Bits)
  14. iPv6MaxMask = net.CIDRMask(iPv6Bits, iPv6Bits)
  15. )
  16. // Tree is a radix tree for IPv4 and IPv6 networks.
  17. type Tree struct {
  18. root32 *numtree.Node32
  19. root64 *numtree.Node64
  20. }
  21. // Pair represents a key-value pair returned by Enumerate method.
  22. type Pair struct {
  23. Key *net.IPNet
  24. Value interface{}
  25. }
  26. type subTree64 *numtree.Node64
  27. // NewTree creates empty tree.
  28. func NewTree() *Tree {
  29. return &Tree{}
  30. }
  31. // InsertNet inserts value using given network as a key. The method returns new tree (old one remains unaffected).
  32. func (t *Tree) InsertNet(n *net.IPNet, value interface{}) *Tree {
  33. if n == nil {
  34. return t
  35. }
  36. if key, bits := iPv4NetToUint32(n); bits >= 0 {
  37. var (
  38. r32 *numtree.Node32
  39. r64 *numtree.Node64
  40. )
  41. if t != nil {
  42. r32 = t.root32
  43. r64 = t.root64
  44. }
  45. return &Tree{root32: r32.Insert(key, bits, value), root64: r64}
  46. }
  47. if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
  48. var (
  49. r32 *numtree.Node32
  50. r64 *numtree.Node64
  51. )
  52. if t != nil {
  53. r32 = t.root32
  54. r64 = t.root64
  55. }
  56. if MSBits < numtree.Key64BitSize {
  57. return &Tree{root32: r32, root64: r64.Insert(MSKey, MSBits, value)}
  58. }
  59. var r *numtree.Node64
  60. if v, ok := r64.ExactMatch(MSKey, MSBits); ok {
  61. s, ok := v.(subTree64)
  62. if !ok {
  63. err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
  64. MSKey, MSBits, v, v)
  65. panic(err)
  66. }
  67. r = (*numtree.Node64)(s)
  68. }
  69. r = r.Insert(LSKey, LSBits, value)
  70. return &Tree{root32: r32, root64: r64.Insert(MSKey, MSBits, subTree64(r))}
  71. }
  72. return t
  73. }
  74. // InplaceInsertNet inserts (or replaces) value using given network as a key in current tree.
  75. func (t *Tree) InplaceInsertNet(n *net.IPNet, value interface{}) {
  76. if n == nil {
  77. return
  78. }
  79. if key, bits := iPv4NetToUint32(n); bits >= 0 {
  80. t.root32 = t.root32.InplaceInsert(key, bits, value)
  81. } else if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
  82. if MSBits < numtree.Key64BitSize {
  83. t.root64 = t.root64.InplaceInsert(MSKey, MSBits, value)
  84. } else {
  85. if v, ok := t.root64.ExactMatch(MSKey, MSBits); ok {
  86. s, ok := v.(subTree64)
  87. if !ok {
  88. err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
  89. MSKey, MSBits, v, v)
  90. panic(err)
  91. }
  92. r := (*numtree.Node64)(s)
  93. newR := r.InplaceInsert(LSKey, LSBits, value)
  94. if newR != r {
  95. t.root64 = t.root64.InplaceInsert(MSKey, MSBits, subTree64(newR))
  96. }
  97. } else {
  98. var r *numtree.Node64
  99. r = r.InplaceInsert(LSKey, LSBits, value)
  100. t.root64 = t.root64.InplaceInsert(MSKey, MSBits, subTree64(r))
  101. }
  102. }
  103. }
  104. }
  105. // InsertIP inserts value using given IP address as a key. The method returns new tree (old one remains unaffected).
  106. func (t *Tree) InsertIP(ip net.IP, value interface{}) *Tree {
  107. return t.InsertNet(newIPNetFromIP(ip), value)
  108. }
  109. // InplaceInsertIP inserts (or replaces) value using given IP address as a key in current tree.
  110. func (t *Tree) InplaceInsertIP(ip net.IP, value interface{}) {
  111. t.InplaceInsertNet(newIPNetFromIP(ip), value)
  112. }
  113. // Enumerate returns channel which is populated by key-value pairs of tree content.
  114. func (t *Tree) Enumerate() chan Pair {
  115. ch := make(chan Pair)
  116. go func() {
  117. defer close(ch)
  118. if t == nil {
  119. return
  120. }
  121. t.enumerate(ch)
  122. }()
  123. return ch
  124. }
  125. // GetByNet gets value for network which is equal to or contains given network.
  126. func (t *Tree) GetByNet(n *net.IPNet) (interface{}, bool) {
  127. if t == nil || n == nil {
  128. return nil, false
  129. }
  130. if key, bits := iPv4NetToUint32(n); bits >= 0 {
  131. return t.root32.Match(key, bits)
  132. }
  133. if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
  134. v, ok := t.root64.Match(MSKey, MSBits)
  135. if !ok || MSBits < numtree.Key64BitSize {
  136. return v, ok
  137. }
  138. s, ok := v.(subTree64)
  139. if !ok {
  140. return v, true
  141. }
  142. v, ok = (*numtree.Node64)(s).Match(LSKey, LSBits)
  143. if ok {
  144. return v, ok
  145. }
  146. return t.root64.Match(MSKey, numtree.Key64BitSize-1)
  147. }
  148. return nil, false
  149. }
  150. // GetByIP gets value for network which is equal to or contains given IP address.
  151. func (t *Tree) GetByIP(ip net.IP) (interface{}, bool) {
  152. return t.GetByNet(newIPNetFromIP(ip))
  153. }
  154. // DeleteByNet removes subtree which is contained by given network. The method returns new tree (old one remains unaffected) and flag indicating if deletion happens indeed.
  155. func (t *Tree) DeleteByNet(n *net.IPNet) (*Tree, bool) {
  156. if t == nil || n == nil {
  157. return t, false
  158. }
  159. if key, bits := iPv4NetToUint32(n); bits >= 0 {
  160. r, ok := t.root32.Delete(key, bits)
  161. if ok {
  162. return &Tree{root32: r, root64: t.root64}, true
  163. }
  164. } else if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
  165. r64 := t.root64
  166. if MSBits < numtree.Key64BitSize {
  167. r64, ok := r64.Delete(MSKey, MSBits)
  168. if ok {
  169. return &Tree{root32: t.root32, root64: r64}, true
  170. }
  171. } else if v, ok := r64.ExactMatch(MSKey, MSBits); ok {
  172. s, ok := v.(subTree64)
  173. if !ok {
  174. err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
  175. MSKey, MSBits, v, v)
  176. panic(err)
  177. }
  178. r, ok := (*numtree.Node64)(s).Delete(LSKey, LSBits)
  179. if ok {
  180. if r == nil {
  181. r64, _ = r64.Delete(MSKey, MSBits)
  182. } else {
  183. r64 = r64.Insert(MSKey, MSBits, subTree64(r))
  184. }
  185. return &Tree{root32: t.root32, root64: r64}, true
  186. }
  187. }
  188. }
  189. return t, false
  190. }
  191. // DeleteByIP removes node by given IP address. The method returns new tree (old one remains unaffected) and flag indicating if deletion happens indeed.
  192. func (t *Tree) DeleteByIP(ip net.IP) (*Tree, bool) {
  193. return t.DeleteByNet(newIPNetFromIP(ip))
  194. }
  195. func (t *Tree) enumerate(ch chan Pair) {
  196. for n := range t.root32.Enumerate() {
  197. mask := net.CIDRMask(int(n.Bits), iPv4Bits)
  198. ch <- Pair{
  199. Key: &net.IPNet{
  200. IP: unpackUint32ToIP(n.Key).Mask(mask),
  201. Mask: mask},
  202. Value: n.Value}
  203. }
  204. for n := range t.root64.Enumerate() {
  205. MSIP := append(unpackUint64ToIP(n.Key), make(net.IP, 8)...)
  206. if s, ok := n.Value.(subTree64); ok {
  207. for n := range (*numtree.Node64)(s).Enumerate() {
  208. LSIP := unpackUint64ToIP(n.Key)
  209. mask := net.CIDRMask(numtree.Key64BitSize+int(n.Bits), iPv6Bits)
  210. ch <- Pair{
  211. Key: &net.IPNet{
  212. IP: append(MSIP[0:8], LSIP...).Mask(mask),
  213. Mask: mask},
  214. Value: n.Value}
  215. }
  216. } else {
  217. mask := net.CIDRMask(int(n.Bits), iPv6Bits)
  218. ch <- Pair{
  219. Key: &net.IPNet{
  220. IP: MSIP.Mask(mask),
  221. Mask: mask},
  222. Value: n.Value}
  223. }
  224. }
  225. }
  226. func iPv4NetToUint32(n *net.IPNet) (uint32, int) {
  227. if len(n.IP) != net.IPv4len {
  228. return 0, -1
  229. }
  230. ones, bits := n.Mask.Size()
  231. if bits != iPv4Bits {
  232. return 0, -1
  233. }
  234. return packIPToUint32(n.IP), ones
  235. }
  236. func packIPToUint32(x net.IP) uint32 {
  237. return (uint32(x[0]) << 24) | (uint32(x[1]) << 16) | (uint32(x[2]) << 8) | uint32(x[3])
  238. }
  239. func unpackUint32ToIP(x uint32) net.IP {
  240. return net.IP{byte(x >> 24 & 0xff), byte(x >> 16 & 0xff), byte(x >> 8 & 0xff), byte(x & 0xff)}
  241. }
  242. func iPv6NetToUint64Pair(n *net.IPNet) (uint64, int, uint64, int) {
  243. if len(n.IP) != net.IPv6len {
  244. return 0, -1, 0, -1
  245. }
  246. ones, bits := n.Mask.Size()
  247. if bits != iPv6Bits {
  248. return 0, -1, 0, -1
  249. }
  250. MSBits := numtree.Key64BitSize
  251. LSBits := 0
  252. if ones > numtree.Key64BitSize {
  253. LSBits = ones - numtree.Key64BitSize
  254. } else {
  255. MSBits = ones
  256. }
  257. return packIPToUint64(n.IP), MSBits, packIPToUint64(n.IP[8:]), LSBits
  258. }
  259. func packIPToUint64(x net.IP) uint64 {
  260. return (uint64(x[0]) << 56) | (uint64(x[1]) << 48) | (uint64(x[2]) << 40) | (uint64(x[3]) << 32) |
  261. (uint64(x[4]) << 24) | (uint64(x[5]) << 16) | (uint64(x[6]) << 8) | uint64(x[7])
  262. }
  263. func unpackUint64ToIP(x uint64) net.IP {
  264. return net.IP{
  265. byte(x >> 56 & 0xff),
  266. byte(x >> 48 & 0xff),
  267. byte(x >> 40 & 0xff),
  268. byte(x >> 32 & 0xff),
  269. byte(x >> 24 & 0xff),
  270. byte(x >> 16 & 0xff),
  271. byte(x >> 8 & 0xff),
  272. byte(x & 0xff)}
  273. }
  274. func newIPNetFromIP(ip net.IP) *net.IPNet {
  275. if ip4 := ip.To4(); ip4 != nil {
  276. return &net.IPNet{IP: ip4, Mask: iPv4MaxMask}
  277. }
  278. if ip6 := ip.To16(); ip6 != nil {
  279. return &net.IPNet{IP: ip6, Mask: iPv6MaxMask}
  280. }
  281. return nil
  282. }