| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package secrules
- import (
- "bytes"
- "net"
- "sort"
- "yunion.io/x/log"
- "yunion.io/x/pkg/gotypes"
- "yunion.io/x/pkg/util/netutils"
- "yunion.io/x/pkg/util/regutils"
- "yunion.io/x/pkg/util/sortutils"
- )
- func isWildNet(ipnet *net.IPNet) bool {
- return gotypes.IsNil(ipnet)
- }
- func compareIPNet(ipnet1, ipnet2 *net.IPNet) sortutils.CompareResult {
- srsIPi := ipnet1.String()
- srsIPj := ipnet2.String()
- if !isWildNet(ipnet1) && !isWildNet(ipnet2) {
- if srsIPi != srsIPj {
- isIPv6i := regutils.MatchCIDR6(srsIPi)
- isIPv6j := regutils.MatchCIDR6(srsIPj)
- if isIPv6i && isIPv6j {
- // compare two ipv6
- v6Rangei := netutils.NewIPV6AddrRangeFromIPNet(ipnet1)
- v6Rangej := netutils.NewIPV6AddrRangeFromIPNet(ipnet2)
- return v6Rangei.Compare(v6Rangej)
- } else if !isIPv6i && !isIPv6j {
- // compare two ipv4
- v4Rangei := netutils.NewIPV4AddrRangeFromIPNet(ipnet1)
- v4Rangej := netutils.NewIPV4AddrRangeFromIPNet(ipnet2)
- return v4Rangei.Compare(v4Rangej)
- } else if isIPv6i && !isIPv6j {
- // v4 first
- return sortutils.More
- } else {
- // if !isIPv6i && isIPv6j {
- // v4 first
- return sortutils.Less
- }
- } else {
- return sortutils.Equal
- }
- } else if isWildNet(ipnet1) && !isWildNet(ipnet2) {
- return sortutils.Less
- } else if isWildNet(ipnet1) && !isWildNet(ipnet2) {
- return sortutils.More
- } else {
- // both wild net, go to next
- return sortutils.Equal
- }
- }
- func isWildProtocol(protocol string) bool {
- return len(protocol) == 0 || protocol == PROTO_ANY
- }
- func compareProtocol(protocol1, protocol2 string) sortutils.CompareResult {
- isWild1 := isWildProtocol(protocol1)
- isWild2 := isWildProtocol(protocol1)
- if isWild1 && isWild2 {
- return sortutils.Equal
- } else if isWild1 && !isWild2 {
- return sortutils.Less
- } else if !isWild1 && isWild2 {
- return sortutils.More
- } else {
- return sortutils.CompareString(protocol1, protocol2)
- }
- }
- type SecurityRuleSet []SecurityRule
- func (srs SecurityRuleSet) Len() int {
- return len(srs)
- }
- func (srs SecurityRuleSet) Swap(i, j int) {
- srs[i], srs[j] = srs[j], srs[i]
- }
- func (srs SecurityRuleSet) Less(i, j int) bool {
- if srs[i].Priority > srs[j].Priority {
- return true
- } else if srs[i].Priority < srs[j].Priority {
- return false
- }
- // priority equals, compare ipnet
- {
- result := compareIPNet(srs[i].IPNet, srs[j].IPNet)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- // compare protocol
- {
- result := compareProtocol(srs[i].Protocol, srs[j].Protocol)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- return srs[i].String() < srs[j].String()
- }
- func (srs SecurityRuleSet) stringList() []string {
- r := make([]string, len(srs))
- for i := range srs {
- r = append(r, srs[i].String())
- }
- return r
- }
- func (srs SecurityRuleSet) String() string {
- buf := bytes.Buffer{}
- for i := range srs {
- buf.WriteString(srs[i].String())
- buf.WriteString(";")
- }
- return buf.String()
- }
- func (srs SecurityRuleSet) Equals(srs1 SecurityRuleSet) bool {
- sort.Sort(srs)
- sort.Sort(srs1)
- return srs.equals(srs1)
- }
- func (srs SecurityRuleSet) equals(srs1 SecurityRuleSet) bool {
- if len(srs) != len(srs1) {
- return false
- }
- for i := range srs {
- if !srs[i].equals(&srs1[i]) {
- return false
- }
- }
- return true
- }
- // convert to pure allow list
- //
- // requirements on srs
- //
- // - ordered by priority
- // - same direction
- //
- /*func (srs SecurityRuleSet) AllowList() SecurityRuleSet {
- allowList := SecurityRuleSet{}
- denyList := SecurityRuleSet{}
- for i := range srs {
- if srs[i].Action == SecurityRuleAllow {
- allowList = append(allowList, srs[i])
- } else {
- denyList = append(denyList, srs[i])
- }
- }
- sort.Sort(allowList)
- allowList.uniq()
- if len(denyList) > 0 {
- sort.Sort(denyList)
- denyList.uniq()
- for i := range denyList {
- allowList = allowList.cutOut(denyList[i])
- }
- }
- allowList = allowList.collapse()
- return allowList
- }
- func (srs SecurityRuleSet) cutOut(r SecurityRule) SecurityRuleSet {
- cutRes := SecurityRuleSet{}
- for i := range srs {
- cutout := srs[i].cutOut(r)
- cutRes = append(cutRes, cutout...)
- }
- return cutRes
- }
- func (srs SecurityRuleSet) cutOutFirst() SecurityRuleSet {
- r := SecurityRuleSet{}
- if len(srs) == 0 {
- return r
- }
- sr := srs[0]
- srs_ := srs[1:]
- for _, sr_ := range srs_ {
- if sr.Action == sr_.Action {
- r = append(r, sr_)
- continue
- }
- cut := sr_.cutOut(sr)
- r = append(r, cut...)
- }
- return r
- }*/
- // remove duplicate rules
- func (srs SecurityRuleSet) uniq() SecurityRuleSet {
- for i := len(srs) - 1; i > 0; i-- {
- sr0 := &srs[i-1]
- sr1 := &srs[i]
- if sr0.String() != sr1.String() {
- continue
- }
- srs = append(srs[:i], srs[i+1:]...)
- }
- return srs
- }
- // collapse result of AllowList
- //
- // - same direction
- // - same action
- //
- // As they share the same action, priority's influence on order of rules can be ignored
- //
- func (srs SecurityRuleSet) collapse() SecurityRuleSet {
- srs1 := make(SecurityRuleSet, len(srs))
- copy(srs1, srs)
- for i := range srs1 {
- sr := &srs1[i]
- if len(sr.Ports) > 0 {
- sort.Slice(sr.Ports, func(i, j int) bool {
- return sr.Ports[i] < sr.Ports[j]
- })
- }
- }
- sort.Slice(srs1, func(i, j int) bool {
- sr0 := &srs1[i]
- sr1 := &srs1[j]
- {
- result := compareProtocol(sr0.Protocol, sr1.Protocol)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- {
- result := compareIPNet(sr0.IPNet, sr1.IPNet)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- if sr0.PortStart > 0 && sr0.PortEnd > 0 {
- if sr1.PortStart > 0 && sr1.PortEnd > 0 {
- return sr0.PortStart < sr1.PortStart
- }
- // port range comes first
- return true
- } else if len(sr0.Ports) > 0 {
- if sr1.PortStart > 0 && sr1.PortEnd > 0 {
- return false
- } else if len(sr1.Ports) > 0 {
- sr0l := len(sr0.Ports)
- sr1l := len(sr1.Ports)
- for i := 0; i < sr0l && i < sr1l; i++ {
- if sr0.Ports[i] != sr1.Ports[i] {
- return sr0.Ports[i] < sr1.Ports[i]
- }
- }
- return sr0l < sr1l
- }
- }
- return sr0.Priority < sr1.Priority
- })
- // merge ports
- for i := len(srs1) - 1; i > 0; i-- {
- sr0 := &srs1[i-1]
- sr1 := &srs1[i]
- if sr0.Protocol != sr1.Protocol {
- continue
- }
- if !sr0.netEquals(sr1) {
- continue
- }
- if (len(sr0.Ports) > 0 || (sr0.PortStart == sr0.PortEnd && sr0.PortStart > 0)) && (len(sr1.Ports) > 0 || (sr1.PortStart == sr1.PortEnd && sr1.PortStart > 0)) {
- ps := newPortsFromInts(sr0.Ports...)
- ps = append(ps, newPortsFromInts(sr1.Ports...)...)
- if sr0.PortStart == sr0.PortEnd && sr0.PortStart > 0 {
- ps = append(ps, uint16(sr0.PortStart))
- }
- if sr1.PortStart == sr1.PortEnd && sr1.PortStart > 0 {
- ps = append(ps, uint16(sr1.PortStart))
- }
- ps = ps.dedup()
- sr0.Ports = ps.IntSlice()
- sr0.PortStart, sr0.PortEnd = -1, -1
- srs1 = append(srs1[:i], srs1[i+1:]...)
- } else if sr0.PortStart > 0 && sr1.PortStart > 0 && sr0.PortEnd > 0 && sr1.PortEnd > 0 {
- if sr0.PortEnd == sr1.PortStart-1 {
- sr0.PortEnd = sr1.PortEnd
- srs1 = append(srs1[:i], srs1[i+1:]...)
- } else if sr0.PortStart-1 == sr1.PortEnd {
- sr0.PortStart = sr1.PortStart
- srs1 = append(srs1[:i], srs1[i+1:]...)
- } else if sr0.PortStart == sr1.PortStart && sr0.PortEnd == sr1.PortEnd {
- srs1 = append(srs1[:i], srs1[i+1:]...)
- }
- // save that contains, intersects
- }
- }
- for i := range srs1 {
- sr := &srs1[i]
- if sr.PortStart <= 1 && sr.PortEnd >= 65535 {
- sr.PortStart = -1
- sr.PortEnd = -1
- }
- }
- //merge cidr
- sort.Slice(srs1, func(i, j int) bool {
- sr0 := &srs1[i]
- sr1 := &srs1[j]
- {
- result := compareProtocol(sr0.Protocol, sr1.Protocol)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- if sr0.GetPortsString() != sr1.GetPortsString() {
- return sr0.GetPortsString() < sr1.GetPortsString()
- }
- {
- result := compareIPNet(sr0.IPNet, sr1.IPNet)
- switch result {
- case sortutils.Less:
- return true
- case sortutils.More:
- return false
- }
- }
- return sr0.Priority < sr1.Priority
- })
- // 将端口和协议相同的规则归类
- needMerged := []SecurityRuleSet{}
- for i, j := 0, 0; i < len(srs1); i++ {
- if i == 0 {
- needMerged = append(needMerged, SecurityRuleSet{srs1[i]})
- continue
- }
- last := needMerged[j][len(needMerged[j])-1]
- if last.Protocol == srs1[i].Protocol && last.GetPortsString() == srs1[i].GetPortsString() {
- needMerged[j] = append(needMerged[j], srs1[i])
- continue
- }
- needMerged = append(needMerged, SecurityRuleSet{srs1[i]})
- j++
- }
- result := SecurityRuleSet{}
- for _, srs := range needMerged {
- result = append(result, srs.mergeNet()...)
- }
- result = result.uniq()
- for i := range result {
- sr := &result[i]
- sr.Priority = 1
- }
- return result
- }
- func (srs SecurityRuleSet) mergeNet() SecurityRuleSet {
- ranges4 := []netutils.IPV4AddrRange{}
- ranges6 := []netutils.IPV6AddrRange{}
- for i := 0; i < len(srs); i++ {
- if isWildNet(srs[i].IPNet) {
- // wild mark
- ranges4 = append(ranges4, netutils.AllIPV4AddrRange)
- ranges6 = append(ranges6, netutils.AllIPV6AddrRange)
- } else {
- cidr := srs[i].IPNet.String()
- if regutils.MatchCIDR6(cidr) {
- // ipv6
- ranges6 = append(ranges6, netutils.NewIPV6AddrRangeFromIPNet(srs[i].IPNet))
- } else {
- ranges4 = append(ranges4, netutils.NewIPV4AddrRangeFromIPNet(srs[i].IPNet))
- }
- }
- }
- ranges4 = netutils.IPV4AddrRangeList(ranges4).Merge()
- ranges6 = netutils.IPV6AddrRangeList(ranges6).Merge()
- nets := []*net.IPNet{}
- hasWildNet4 := false
- hasWildNet6 := false
- for i := range ranges4 {
- addr := ranges4[i]
- for _, n := range addr.ToIPNets() {
- if n.String() == "0.0.0.0/0" {
- hasWildNet4 = true
- } else {
- nets = append(nets, n)
- log.Debugf("merge v4 %s", n.String())
- }
- }
- }
- for i := range ranges6 {
- addr := ranges6[i]
- for _, n := range addr.ToIPNets() {
- if n.String() == "::/0" {
- hasWildNet6 = true
- } else {
- nets = append(nets, n)
- }
- }
- }
- result := SecurityRuleSet{}
- if hasWildNet4 && hasWildNet6 {
- val := srs[0]
- val.IPNet = nil
- result = append(result, val)
- } else if hasWildNet4 {
- val := srs[0]
- val.IPNet = &net.IPNet{
- IP: net.IPv4zero,
- Mask: net.CIDRMask(0, 32),
- }
- result = append(result, val)
- } else if hasWildNet6 {
- val := srs[0]
- val.IPNet = &net.IPNet{
- IP: net.IPv6zero,
- Mask: net.CIDRMask(0, 128),
- }
- result = append(result, val)
- }
- for _, net := range nets {
- val := srs[0]
- val.IPNet = net
- result = append(result, val)
- }
- return result
- }
|