secruleset.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package secrules
  15. import (
  16. "bytes"
  17. "net"
  18. "sort"
  19. "yunion.io/x/log"
  20. "yunion.io/x/pkg/gotypes"
  21. "yunion.io/x/pkg/util/netutils"
  22. "yunion.io/x/pkg/util/regutils"
  23. "yunion.io/x/pkg/util/sortutils"
  24. )
  25. func isWildNet(ipnet *net.IPNet) bool {
  26. return gotypes.IsNil(ipnet)
  27. }
  28. func compareIPNet(ipnet1, ipnet2 *net.IPNet) sortutils.CompareResult {
  29. srsIPi := ipnet1.String()
  30. srsIPj := ipnet2.String()
  31. if !isWildNet(ipnet1) && !isWildNet(ipnet2) {
  32. if srsIPi != srsIPj {
  33. isIPv6i := regutils.MatchCIDR6(srsIPi)
  34. isIPv6j := regutils.MatchCIDR6(srsIPj)
  35. if isIPv6i && isIPv6j {
  36. // compare two ipv6
  37. v6Rangei := netutils.NewIPV6AddrRangeFromIPNet(ipnet1)
  38. v6Rangej := netutils.NewIPV6AddrRangeFromIPNet(ipnet2)
  39. return v6Rangei.Compare(v6Rangej)
  40. } else if !isIPv6i && !isIPv6j {
  41. // compare two ipv4
  42. v4Rangei := netutils.NewIPV4AddrRangeFromIPNet(ipnet1)
  43. v4Rangej := netutils.NewIPV4AddrRangeFromIPNet(ipnet2)
  44. return v4Rangei.Compare(v4Rangej)
  45. } else if isIPv6i && !isIPv6j {
  46. // v4 first
  47. return sortutils.More
  48. } else {
  49. // if !isIPv6i && isIPv6j {
  50. // v4 first
  51. return sortutils.Less
  52. }
  53. } else {
  54. return sortutils.Equal
  55. }
  56. } else if isWildNet(ipnet1) && !isWildNet(ipnet2) {
  57. return sortutils.Less
  58. } else if isWildNet(ipnet1) && !isWildNet(ipnet2) {
  59. return sortutils.More
  60. } else {
  61. // both wild net, go to next
  62. return sortutils.Equal
  63. }
  64. }
  65. func isWildProtocol(protocol string) bool {
  66. return len(protocol) == 0 || protocol == PROTO_ANY
  67. }
  68. func compareProtocol(protocol1, protocol2 string) sortutils.CompareResult {
  69. isWild1 := isWildProtocol(protocol1)
  70. isWild2 := isWildProtocol(protocol1)
  71. if isWild1 && isWild2 {
  72. return sortutils.Equal
  73. } else if isWild1 && !isWild2 {
  74. return sortutils.Less
  75. } else if !isWild1 && isWild2 {
  76. return sortutils.More
  77. } else {
  78. return sortutils.CompareString(protocol1, protocol2)
  79. }
  80. }
  81. type SecurityRuleSet []SecurityRule
  82. func (srs SecurityRuleSet) Len() int {
  83. return len(srs)
  84. }
  85. func (srs SecurityRuleSet) Swap(i, j int) {
  86. srs[i], srs[j] = srs[j], srs[i]
  87. }
  88. func (srs SecurityRuleSet) Less(i, j int) bool {
  89. if srs[i].Priority > srs[j].Priority {
  90. return true
  91. } else if srs[i].Priority < srs[j].Priority {
  92. return false
  93. }
  94. // priority equals, compare ipnet
  95. {
  96. result := compareIPNet(srs[i].IPNet, srs[j].IPNet)
  97. switch result {
  98. case sortutils.Less:
  99. return true
  100. case sortutils.More:
  101. return false
  102. }
  103. }
  104. // compare protocol
  105. {
  106. result := compareProtocol(srs[i].Protocol, srs[j].Protocol)
  107. switch result {
  108. case sortutils.Less:
  109. return true
  110. case sortutils.More:
  111. return false
  112. }
  113. }
  114. return srs[i].String() < srs[j].String()
  115. }
  116. func (srs SecurityRuleSet) stringList() []string {
  117. r := make([]string, len(srs))
  118. for i := range srs {
  119. r = append(r, srs[i].String())
  120. }
  121. return r
  122. }
  123. func (srs SecurityRuleSet) String() string {
  124. buf := bytes.Buffer{}
  125. for i := range srs {
  126. buf.WriteString(srs[i].String())
  127. buf.WriteString(";")
  128. }
  129. return buf.String()
  130. }
  131. func (srs SecurityRuleSet) Equals(srs1 SecurityRuleSet) bool {
  132. sort.Sort(srs)
  133. sort.Sort(srs1)
  134. return srs.equals(srs1)
  135. }
  136. func (srs SecurityRuleSet) equals(srs1 SecurityRuleSet) bool {
  137. if len(srs) != len(srs1) {
  138. return false
  139. }
  140. for i := range srs {
  141. if !srs[i].equals(&srs1[i]) {
  142. return false
  143. }
  144. }
  145. return true
  146. }
  147. // convert to pure allow list
  148. //
  149. // requirements on srs
  150. //
  151. // - ordered by priority
  152. // - same direction
  153. //
  154. /*func (srs SecurityRuleSet) AllowList() SecurityRuleSet {
  155. allowList := SecurityRuleSet{}
  156. denyList := SecurityRuleSet{}
  157. for i := range srs {
  158. if srs[i].Action == SecurityRuleAllow {
  159. allowList = append(allowList, srs[i])
  160. } else {
  161. denyList = append(denyList, srs[i])
  162. }
  163. }
  164. sort.Sort(allowList)
  165. allowList.uniq()
  166. if len(denyList) > 0 {
  167. sort.Sort(denyList)
  168. denyList.uniq()
  169. for i := range denyList {
  170. allowList = allowList.cutOut(denyList[i])
  171. }
  172. }
  173. allowList = allowList.collapse()
  174. return allowList
  175. }
  176. func (srs SecurityRuleSet) cutOut(r SecurityRule) SecurityRuleSet {
  177. cutRes := SecurityRuleSet{}
  178. for i := range srs {
  179. cutout := srs[i].cutOut(r)
  180. cutRes = append(cutRes, cutout...)
  181. }
  182. return cutRes
  183. }
  184. func (srs SecurityRuleSet) cutOutFirst() SecurityRuleSet {
  185. r := SecurityRuleSet{}
  186. if len(srs) == 0 {
  187. return r
  188. }
  189. sr := srs[0]
  190. srs_ := srs[1:]
  191. for _, sr_ := range srs_ {
  192. if sr.Action == sr_.Action {
  193. r = append(r, sr_)
  194. continue
  195. }
  196. cut := sr_.cutOut(sr)
  197. r = append(r, cut...)
  198. }
  199. return r
  200. }*/
  201. // remove duplicate rules
  202. func (srs SecurityRuleSet) uniq() SecurityRuleSet {
  203. for i := len(srs) - 1; i > 0; i-- {
  204. sr0 := &srs[i-1]
  205. sr1 := &srs[i]
  206. if sr0.String() != sr1.String() {
  207. continue
  208. }
  209. srs = append(srs[:i], srs[i+1:]...)
  210. }
  211. return srs
  212. }
  213. // collapse result of AllowList
  214. //
  215. // - same direction
  216. // - same action
  217. //
  218. // As they share the same action, priority's influence on order of rules can be ignored
  219. //
  220. func (srs SecurityRuleSet) collapse() SecurityRuleSet {
  221. srs1 := make(SecurityRuleSet, len(srs))
  222. copy(srs1, srs)
  223. for i := range srs1 {
  224. sr := &srs1[i]
  225. if len(sr.Ports) > 0 {
  226. sort.Slice(sr.Ports, func(i, j int) bool {
  227. return sr.Ports[i] < sr.Ports[j]
  228. })
  229. }
  230. }
  231. sort.Slice(srs1, func(i, j int) bool {
  232. sr0 := &srs1[i]
  233. sr1 := &srs1[j]
  234. {
  235. result := compareProtocol(sr0.Protocol, sr1.Protocol)
  236. switch result {
  237. case sortutils.Less:
  238. return true
  239. case sortutils.More:
  240. return false
  241. }
  242. }
  243. {
  244. result := compareIPNet(sr0.IPNet, sr1.IPNet)
  245. switch result {
  246. case sortutils.Less:
  247. return true
  248. case sortutils.More:
  249. return false
  250. }
  251. }
  252. if sr0.PortStart > 0 && sr0.PortEnd > 0 {
  253. if sr1.PortStart > 0 && sr1.PortEnd > 0 {
  254. return sr0.PortStart < sr1.PortStart
  255. }
  256. // port range comes first
  257. return true
  258. } else if len(sr0.Ports) > 0 {
  259. if sr1.PortStart > 0 && sr1.PortEnd > 0 {
  260. return false
  261. } else if len(sr1.Ports) > 0 {
  262. sr0l := len(sr0.Ports)
  263. sr1l := len(sr1.Ports)
  264. for i := 0; i < sr0l && i < sr1l; i++ {
  265. if sr0.Ports[i] != sr1.Ports[i] {
  266. return sr0.Ports[i] < sr1.Ports[i]
  267. }
  268. }
  269. return sr0l < sr1l
  270. }
  271. }
  272. return sr0.Priority < sr1.Priority
  273. })
  274. // merge ports
  275. for i := len(srs1) - 1; i > 0; i-- {
  276. sr0 := &srs1[i-1]
  277. sr1 := &srs1[i]
  278. if sr0.Protocol != sr1.Protocol {
  279. continue
  280. }
  281. if !sr0.netEquals(sr1) {
  282. continue
  283. }
  284. if (len(sr0.Ports) > 0 || (sr0.PortStart == sr0.PortEnd && sr0.PortStart > 0)) && (len(sr1.Ports) > 0 || (sr1.PortStart == sr1.PortEnd && sr1.PortStart > 0)) {
  285. ps := newPortsFromInts(sr0.Ports...)
  286. ps = append(ps, newPortsFromInts(sr1.Ports...)...)
  287. if sr0.PortStart == sr0.PortEnd && sr0.PortStart > 0 {
  288. ps = append(ps, uint16(sr0.PortStart))
  289. }
  290. if sr1.PortStart == sr1.PortEnd && sr1.PortStart > 0 {
  291. ps = append(ps, uint16(sr1.PortStart))
  292. }
  293. ps = ps.dedup()
  294. sr0.Ports = ps.IntSlice()
  295. sr0.PortStart, sr0.PortEnd = -1, -1
  296. srs1 = append(srs1[:i], srs1[i+1:]...)
  297. } else if sr0.PortStart > 0 && sr1.PortStart > 0 && sr0.PortEnd > 0 && sr1.PortEnd > 0 {
  298. if sr0.PortEnd == sr1.PortStart-1 {
  299. sr0.PortEnd = sr1.PortEnd
  300. srs1 = append(srs1[:i], srs1[i+1:]...)
  301. } else if sr0.PortStart-1 == sr1.PortEnd {
  302. sr0.PortStart = sr1.PortStart
  303. srs1 = append(srs1[:i], srs1[i+1:]...)
  304. } else if sr0.PortStart == sr1.PortStart && sr0.PortEnd == sr1.PortEnd {
  305. srs1 = append(srs1[:i], srs1[i+1:]...)
  306. }
  307. // save that contains, intersects
  308. }
  309. }
  310. for i := range srs1 {
  311. sr := &srs1[i]
  312. if sr.PortStart <= 1 && sr.PortEnd >= 65535 {
  313. sr.PortStart = -1
  314. sr.PortEnd = -1
  315. }
  316. }
  317. //merge cidr
  318. sort.Slice(srs1, func(i, j int) bool {
  319. sr0 := &srs1[i]
  320. sr1 := &srs1[j]
  321. {
  322. result := compareProtocol(sr0.Protocol, sr1.Protocol)
  323. switch result {
  324. case sortutils.Less:
  325. return true
  326. case sortutils.More:
  327. return false
  328. }
  329. }
  330. if sr0.GetPortsString() != sr1.GetPortsString() {
  331. return sr0.GetPortsString() < sr1.GetPortsString()
  332. }
  333. {
  334. result := compareIPNet(sr0.IPNet, sr1.IPNet)
  335. switch result {
  336. case sortutils.Less:
  337. return true
  338. case sortutils.More:
  339. return false
  340. }
  341. }
  342. return sr0.Priority < sr1.Priority
  343. })
  344. // 将端口和协议相同的规则归类
  345. needMerged := []SecurityRuleSet{}
  346. for i, j := 0, 0; i < len(srs1); i++ {
  347. if i == 0 {
  348. needMerged = append(needMerged, SecurityRuleSet{srs1[i]})
  349. continue
  350. }
  351. last := needMerged[j][len(needMerged[j])-1]
  352. if last.Protocol == srs1[i].Protocol && last.GetPortsString() == srs1[i].GetPortsString() {
  353. needMerged[j] = append(needMerged[j], srs1[i])
  354. continue
  355. }
  356. needMerged = append(needMerged, SecurityRuleSet{srs1[i]})
  357. j++
  358. }
  359. result := SecurityRuleSet{}
  360. for _, srs := range needMerged {
  361. result = append(result, srs.mergeNet()...)
  362. }
  363. result = result.uniq()
  364. for i := range result {
  365. sr := &result[i]
  366. sr.Priority = 1
  367. }
  368. return result
  369. }
  370. func (srs SecurityRuleSet) mergeNet() SecurityRuleSet {
  371. ranges4 := []netutils.IPV4AddrRange{}
  372. ranges6 := []netutils.IPV6AddrRange{}
  373. for i := 0; i < len(srs); i++ {
  374. if isWildNet(srs[i].IPNet) {
  375. // wild mark
  376. ranges4 = append(ranges4, netutils.AllIPV4AddrRange)
  377. ranges6 = append(ranges6, netutils.AllIPV6AddrRange)
  378. } else {
  379. cidr := srs[i].IPNet.String()
  380. if regutils.MatchCIDR6(cidr) {
  381. // ipv6
  382. ranges6 = append(ranges6, netutils.NewIPV6AddrRangeFromIPNet(srs[i].IPNet))
  383. } else {
  384. ranges4 = append(ranges4, netutils.NewIPV4AddrRangeFromIPNet(srs[i].IPNet))
  385. }
  386. }
  387. }
  388. ranges4 = netutils.IPV4AddrRangeList(ranges4).Merge()
  389. ranges6 = netutils.IPV6AddrRangeList(ranges6).Merge()
  390. nets := []*net.IPNet{}
  391. hasWildNet4 := false
  392. hasWildNet6 := false
  393. for i := range ranges4 {
  394. addr := ranges4[i]
  395. for _, n := range addr.ToIPNets() {
  396. if n.String() == "0.0.0.0/0" {
  397. hasWildNet4 = true
  398. } else {
  399. nets = append(nets, n)
  400. log.Debugf("merge v4 %s", n.String())
  401. }
  402. }
  403. }
  404. for i := range ranges6 {
  405. addr := ranges6[i]
  406. for _, n := range addr.ToIPNets() {
  407. if n.String() == "::/0" {
  408. hasWildNet6 = true
  409. } else {
  410. nets = append(nets, n)
  411. }
  412. }
  413. }
  414. result := SecurityRuleSet{}
  415. if hasWildNet4 && hasWildNet6 {
  416. val := srs[0]
  417. val.IPNet = nil
  418. result = append(result, val)
  419. } else if hasWildNet4 {
  420. val := srs[0]
  421. val.IPNet = &net.IPNet{
  422. IP: net.IPv4zero,
  423. Mask: net.CIDRMask(0, 32),
  424. }
  425. result = append(result, val)
  426. } else if hasWildNet6 {
  427. val := srs[0]
  428. val.IPNet = &net.IPNet{
  429. IP: net.IPv6zero,
  430. Mask: net.CIDRMask(0, 128),
  431. }
  432. result = append(result, val)
  433. }
  434. for _, net := range nets {
  435. val := srs[0]
  436. val.IPNet = net
  437. result = append(result, val)
  438. }
  439. return result
  440. }