| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341 |
- // Package iptree implements radix tree data structure for IPv4 and IPv6 networks.
- package iptree
- import (
- "fmt"
- "net"
- "github.com/infobloxopen/go-trees/numtree"
- )
// Address widths in bits, derived from the byte lengths declared by package net.
const (
	iPv4Bits = 8 * net.IPv4len
	iPv6Bits = 8 * net.IPv6len
)

// Full-length masks (all bits set) used when a single IP address is treated
// as a one-host network.
var (
	iPv4MaxMask = net.CIDRMask(iPv4Bits, iPv4Bits)
	iPv6MaxMask = net.CIDRMask(iPv6Bits, iPv6Bits)
)
- // Tree is a radix tree for IPv4 and IPv6 networks.
- type Tree struct {
- root32 *numtree.Node32
- root64 *numtree.Node64
- }
// Pair is a single key-value entry of the tree as delivered by the Enumerate
// method.
type Pair struct {
	// Key is the network the value is stored under.
	Key *net.IPNet
	// Value is the value stored for Key.
	Value interface{}
}
- type subTree64 *numtree.Node64
- // NewTree creates empty tree.
- func NewTree() *Tree {
- return &Tree{}
- }
- // InsertNet inserts value using given network as a key. The method returns new tree (old one remains unaffected).
- func (t *Tree) InsertNet(n *net.IPNet, value interface{}) *Tree {
- if n == nil {
- return t
- }
- if key, bits := iPv4NetToUint32(n); bits >= 0 {
- var (
- r32 *numtree.Node32
- r64 *numtree.Node64
- )
- if t != nil {
- r32 = t.root32
- r64 = t.root64
- }
- return &Tree{root32: r32.Insert(key, bits, value), root64: r64}
- }
- if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
- var (
- r32 *numtree.Node32
- r64 *numtree.Node64
- )
- if t != nil {
- r32 = t.root32
- r64 = t.root64
- }
- if MSBits < numtree.Key64BitSize {
- return &Tree{root32: r32, root64: r64.Insert(MSKey, MSBits, value)}
- }
- var r *numtree.Node64
- if v, ok := r64.ExactMatch(MSKey, MSBits); ok {
- s, ok := v.(subTree64)
- if !ok {
- err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
- MSKey, MSBits, v, v)
- panic(err)
- }
- r = (*numtree.Node64)(s)
- }
- r = r.Insert(LSKey, LSBits, value)
- return &Tree{root32: r32, root64: r64.Insert(MSKey, MSBits, subTree64(r))}
- }
- return t
- }
- // InplaceInsertNet inserts (or replaces) value using given network as a key in current tree.
- func (t *Tree) InplaceInsertNet(n *net.IPNet, value interface{}) {
- if n == nil {
- return
- }
- if key, bits := iPv4NetToUint32(n); bits >= 0 {
- t.root32 = t.root32.InplaceInsert(key, bits, value)
- } else if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
- if MSBits < numtree.Key64BitSize {
- t.root64 = t.root64.InplaceInsert(MSKey, MSBits, value)
- } else {
- if v, ok := t.root64.ExactMatch(MSKey, MSBits); ok {
- s, ok := v.(subTree64)
- if !ok {
- err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
- MSKey, MSBits, v, v)
- panic(err)
- }
- r := (*numtree.Node64)(s)
- newR := r.InplaceInsert(LSKey, LSBits, value)
- if newR != r {
- t.root64 = t.root64.InplaceInsert(MSKey, MSBits, subTree64(newR))
- }
- } else {
- var r *numtree.Node64
- r = r.InplaceInsert(LSKey, LSBits, value)
- t.root64 = t.root64.InplaceInsert(MSKey, MSBits, subTree64(r))
- }
- }
- }
- }
- // InsertIP inserts value using given IP address as a key. The method returns new tree (old one remains unaffected).
- func (t *Tree) InsertIP(ip net.IP, value interface{}) *Tree {
- return t.InsertNet(newIPNetFromIP(ip), value)
- }
- // InplaceInsertIP inserts (or replaces) value using given IP address as a key in current tree.
- func (t *Tree) InplaceInsertIP(ip net.IP, value interface{}) {
- t.InplaceInsertNet(newIPNetFromIP(ip), value)
- }
- // Enumerate returns channel which is populated by key-value pairs of tree content.
- func (t *Tree) Enumerate() chan Pair {
- ch := make(chan Pair)
- go func() {
- defer close(ch)
- if t == nil {
- return
- }
- t.enumerate(ch)
- }()
- return ch
- }
- // GetByNet gets value for network which is equal to or contains given network.
- func (t *Tree) GetByNet(n *net.IPNet) (interface{}, bool) {
- if t == nil || n == nil {
- return nil, false
- }
- if key, bits := iPv4NetToUint32(n); bits >= 0 {
- return t.root32.Match(key, bits)
- }
- if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
- v, ok := t.root64.Match(MSKey, MSBits)
- if !ok || MSBits < numtree.Key64BitSize {
- return v, ok
- }
- s, ok := v.(subTree64)
- if !ok {
- return v, true
- }
- v, ok = (*numtree.Node64)(s).Match(LSKey, LSBits)
- if ok {
- return v, ok
- }
- return t.root64.Match(MSKey, numtree.Key64BitSize-1)
- }
- return nil, false
- }
- // GetByIP gets value for network which is equal to or contains given IP address.
- func (t *Tree) GetByIP(ip net.IP) (interface{}, bool) {
- return t.GetByNet(newIPNetFromIP(ip))
- }
- // DeleteByNet removes subtree which is contained by given network. The method returns new tree (old one remains unaffected) and flag indicating if deletion happens indeed.
- func (t *Tree) DeleteByNet(n *net.IPNet) (*Tree, bool) {
- if t == nil || n == nil {
- return t, false
- }
- if key, bits := iPv4NetToUint32(n); bits >= 0 {
- r, ok := t.root32.Delete(key, bits)
- if ok {
- return &Tree{root32: r, root64: t.root64}, true
- }
- } else if MSKey, MSBits, LSKey, LSBits := iPv6NetToUint64Pair(n); MSBits >= 0 {
- r64 := t.root64
- if MSBits < numtree.Key64BitSize {
- r64, ok := r64.Delete(MSKey, MSBits)
- if ok {
- return &Tree{root32: t.root32, root64: r64}, true
- }
- } else if v, ok := r64.ExactMatch(MSKey, MSBits); ok {
- s, ok := v.(subTree64)
- if !ok {
- err := fmt.Errorf("invalid IPv6 tree: expected subTree64 value at 0x%016x, %d but got %T (%#v)",
- MSKey, MSBits, v, v)
- panic(err)
- }
- r, ok := (*numtree.Node64)(s).Delete(LSKey, LSBits)
- if ok {
- if r == nil {
- r64, _ = r64.Delete(MSKey, MSBits)
- } else {
- r64 = r64.Insert(MSKey, MSBits, subTree64(r))
- }
- return &Tree{root32: t.root32, root64: r64}, true
- }
- }
- }
- return t, false
- }
- // DeleteByIP removes node by given IP address. The method returns new tree (old one remains unaffected) and flag indicating if deletion happens indeed.
- func (t *Tree) DeleteByIP(ip net.IP) (*Tree, bool) {
- return t.DeleteByNet(newIPNetFromIP(ip))
- }
- func (t *Tree) enumerate(ch chan Pair) {
- for n := range t.root32.Enumerate() {
- mask := net.CIDRMask(int(n.Bits), iPv4Bits)
- ch <- Pair{
- Key: &net.IPNet{
- IP: unpackUint32ToIP(n.Key).Mask(mask),
- Mask: mask},
- Value: n.Value}
- }
- for n := range t.root64.Enumerate() {
- MSIP := append(unpackUint64ToIP(n.Key), make(net.IP, 8)...)
- if s, ok := n.Value.(subTree64); ok {
- for n := range (*numtree.Node64)(s).Enumerate() {
- LSIP := unpackUint64ToIP(n.Key)
- mask := net.CIDRMask(numtree.Key64BitSize+int(n.Bits), iPv6Bits)
- ch <- Pair{
- Key: &net.IPNet{
- IP: append(MSIP[0:8], LSIP...).Mask(mask),
- Mask: mask},
- Value: n.Value}
- }
- } else {
- mask := net.CIDRMask(int(n.Bits), iPv6Bits)
- ch <- Pair{
- Key: &net.IPNet{
- IP: MSIP.Mask(mask),
- Mask: mask},
- Value: n.Value}
- }
- }
- }
- func iPv4NetToUint32(n *net.IPNet) (uint32, int) {
- if len(n.IP) != net.IPv4len {
- return 0, -1
- }
- ones, bits := n.Mask.Size()
- if bits != iPv4Bits {
- return 0, -1
- }
- return packIPToUint32(n.IP), ones
- }
// packIPToUint32 packs the first four bytes of x into a uint32, most
// significant byte first. It panics if x has fewer than four bytes.
func packIPToUint32(x net.IP) uint32 {
	var packed uint32
	for i := 0; i < net.IPv4len; i++ {
		packed = packed<<8 | uint32(x[i])
	}
	return packed
}
// unpackUint32ToIP returns x as a 4-byte IP address, most significant byte
// first.
func unpackUint32ToIP(x uint32) net.IP {
	ip := make(net.IP, net.IPv4len)
	// Fill from the last byte backwards, consuming the low byte of x each time.
	for i := len(ip) - 1; i >= 0; i-- {
		ip[i] = byte(x)
		x >>= 8
	}
	return ip
}
- func iPv6NetToUint64Pair(n *net.IPNet) (uint64, int, uint64, int) {
- if len(n.IP) != net.IPv6len {
- return 0, -1, 0, -1
- }
- ones, bits := n.Mask.Size()
- if bits != iPv6Bits {
- return 0, -1, 0, -1
- }
- MSBits := numtree.Key64BitSize
- LSBits := 0
- if ones > numtree.Key64BitSize {
- LSBits = ones - numtree.Key64BitSize
- } else {
- MSBits = ones
- }
- return packIPToUint64(n.IP), MSBits, packIPToUint64(n.IP[8:]), LSBits
- }
// packIPToUint64 packs the first eight bytes of x into a uint64, most
// significant byte first. It panics if x has fewer than eight bytes.
func packIPToUint64(x net.IP) uint64 {
	const width = 8

	var packed uint64
	for i := 0; i < width; i++ {
		packed = packed<<8 | uint64(x[i])
	}
	return packed
}
// unpackUint64ToIP returns x as an 8-byte slice, most significant byte first.
// The result is one half of an IPv6 address, not a complete address.
func unpackUint64ToIP(x uint64) net.IP {
	half := make(net.IP, 8)
	// Fill from the last byte backwards, consuming the low byte of x each time.
	for i := len(half) - 1; i >= 0; i-- {
		half[i] = byte(x)
		x >>= 8
	}
	return half
}
- func newIPNetFromIP(ip net.IP) *net.IPNet {
- if ip4 := ip.To4(); ip4 != nil {
- return &net.IPNet{IP: ip4, Mask: iPv4MaxMask}
- }
- if ip6 := ip.To16(); ip6 != nil {
- return &net.IPNet{IP: ip6, Mask: iPv6MaxMask}
- }
- return nil
- }
|