per_host.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proxy
  5. import (
  6. "context"
  7. "net"
  8. "net/netip"
  9. "strings"
  10. )
  11. // A PerHost directs connections to a default Dialer unless the host name
  12. // requested matches one of a number of exceptions.
  13. type PerHost struct {
  14. def, bypass Dialer
  15. bypassNetworks []*net.IPNet
  16. bypassIPs []net.IP
  17. bypassZones []string
  18. bypassHosts []string
  19. }
  20. // NewPerHost returns a PerHost Dialer that directs connections to either
  21. // defaultDialer or bypass, depending on whether the connection matches one of
  22. // the configured rules.
  23. func NewPerHost(defaultDialer, bypass Dialer) *PerHost {
  24. return &PerHost{
  25. def: defaultDialer,
  26. bypass: bypass,
  27. }
  28. }
  29. // Dial connects to the address addr on the given network through either
  30. // defaultDialer or bypass.
  31. func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
  32. host, _, err := net.SplitHostPort(addr)
  33. if err != nil {
  34. return nil, err
  35. }
  36. return p.dialerForRequest(host).Dial(network, addr)
  37. }
  38. // DialContext connects to the address addr on the given network through either
  39. // defaultDialer or bypass.
  40. func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
  41. host, _, err := net.SplitHostPort(addr)
  42. if err != nil {
  43. return nil, err
  44. }
  45. d := p.dialerForRequest(host)
  46. if x, ok := d.(ContextDialer); ok {
  47. return x.DialContext(ctx, network, addr)
  48. }
  49. return dialContext(ctx, d, network, addr)
  50. }
  51. func (p *PerHost) dialerForRequest(host string) Dialer {
  52. if nip, err := netip.ParseAddr(host); err == nil {
  53. ip := net.IP(nip.AsSlice())
  54. for _, net := range p.bypassNetworks {
  55. if net.Contains(ip) {
  56. return p.bypass
  57. }
  58. }
  59. for _, bypassIP := range p.bypassIPs {
  60. if bypassIP.Equal(ip) {
  61. return p.bypass
  62. }
  63. }
  64. return p.def
  65. }
  66. for _, zone := range p.bypassZones {
  67. if strings.HasSuffix(host, zone) {
  68. return p.bypass
  69. }
  70. if host == zone[1:] {
  71. // For a zone ".example.com", we match "example.com"
  72. // too.
  73. return p.bypass
  74. }
  75. }
  76. for _, bypassHost := range p.bypassHosts {
  77. if bypassHost == host {
  78. return p.bypass
  79. }
  80. }
  81. return p.def
  82. }
  83. // AddFromString parses a string that contains comma-separated values
  84. // specifying hosts that should use the bypass proxy. Each value is either an
  85. // IP address, a CIDR range, a zone (*.example.com) or a host name
  86. // (localhost). A best effort is made to parse the string and errors are
  87. // ignored.
  88. func (p *PerHost) AddFromString(s string) {
  89. hosts := strings.Split(s, ",")
  90. for _, host := range hosts {
  91. host = strings.TrimSpace(host)
  92. if len(host) == 0 {
  93. continue
  94. }
  95. if strings.Contains(host, "/") {
  96. // We assume that it's a CIDR address like 127.0.0.0/8
  97. if _, net, err := net.ParseCIDR(host); err == nil {
  98. p.AddNetwork(net)
  99. }
  100. continue
  101. }
  102. if nip, err := netip.ParseAddr(host); err == nil {
  103. p.AddIP(net.IP(nip.AsSlice()))
  104. continue
  105. }
  106. if strings.HasPrefix(host, "*.") {
  107. p.AddZone(host[1:])
  108. continue
  109. }
  110. p.AddHost(host)
  111. }
  112. }
  113. // AddIP specifies an IP address that will use the bypass proxy. Note that
  114. // this will only take effect if a literal IP address is dialed. A connection
  115. // to a named host will never match an IP.
  116. func (p *PerHost) AddIP(ip net.IP) {
  117. p.bypassIPs = append(p.bypassIPs, ip)
  118. }
  119. // AddNetwork specifies an IP range that will use the bypass proxy. Note that
  120. // this will only take effect if a literal IP address is dialed. A connection
  121. // to a named host will never match.
  122. func (p *PerHost) AddNetwork(net *net.IPNet) {
  123. p.bypassNetworks = append(p.bypassNetworks, net)
  124. }
  125. // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
  126. // "example.com" matches "example.com" and all of its subdomains.
  127. func (p *PerHost) AddZone(zone string) {
  128. zone = strings.TrimSuffix(zone, ".")
  129. if !strings.HasPrefix(zone, ".") {
  130. zone = "." + zone
  131. }
  132. p.bypassZones = append(p.bypassZones, zone)
  133. }
  134. // AddHost specifies a host name that will use the bypass proxy.
  135. func (p *PerHost) AddHost(host string) {
  136. host = strings.TrimSuffix(host, ".")
  137. p.bypassHosts = append(p.bypassHosts, host)
  138. }