conn.go 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. package mdns
  2. import (
  3. "context"
  4. "errors"
  5. "math/big"
  6. "net"
  7. "sync"
  8. "time"
  9. "github.com/pion/logging"
  10. "golang.org/x/net/dns/dnsmessage"
  11. "golang.org/x/net/ipv4"
  12. )
// Conn represents a mDNS Server
//
// A Conn is created by Server. It answers incoming questions for its
// configured local names and resolves other names through Query.
type Conn struct {
	// mu is locked by Query while it appends to queries, read-locked by start
	// while a received packet is processed, and locked while closed is closed.
	mu sync.RWMutex
	// log receives warnings about packets that cannot be built, sent or parsed.
	log logging.LeveledLogger
	// socket is the packet connection handed to Server; all reads and writes go through it.
	socket *ipv4.PacketConn
	// dstAddr is the resolved destinationAddress every question and answer is written to.
	dstAddr *net.UDPAddr
	// queryInterval is the period at which Query re-sends an unanswered question.
	queryInterval time.Duration
	// localNames holds the configured names, each with a trailing "." appended,
	// that incoming questions are compared against.
	localNames []string
	// queries holds the queries that are currently waiting for an answer.
	queries []query
	// closed is closed by start when the read loop exits.
	closed chan interface{}
}
// query is one outstanding lookup registered by Query.
type query struct {
	// nameWithSuffix is the queried name with a trailing "." appended; it is
	// compared against the name of each received answer.
	nameWithSuffix string
	// queryResultChan receives the matching answer from the read loop.
	queryResultChan chan queryResult
}
// queryResult is what the read loop delivers to a waiting Query.
type queryResult struct {
	// answer is the header of the answer record whose name matched the query.
	answer dnsmessage.ResourceHeader
	// addr is the source address of the packet that carried the answer.
	addr net.Addr
}
const (
	// inboundBufferSize is the size, in bytes, of the buffer each incoming packet is read into.
	inboundBufferSize = 512
	// defaultQueryInterval is the re-send period used when Config.QueryInterval is zero.
	defaultQueryInterval = time.Second
	// destinationAddress is the address every question and answer is sent to.
	destinationAddress = "224.0.0.251:5353"
	// maxMessageRecords bounds the loops that read questions and answers from one packet.
	// NOTE(review): those loops run while i <= maxMessageRecords, i.e. up to
	// maxMessageRecords+1 iterations — confirm which limit is intended.
	maxMessageRecords = 3
	// responseTTL is the TTL placed in the header of every answer this Conn sends
	// (presumably seconds, as in DNS — confirm).
	responseTTL = 120
)
  39. // Server establishes a mDNS connection over an existing conn
  40. func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
  41. if config == nil {
  42. return nil, errNilConfig
  43. }
  44. ifaces, err := net.Interfaces()
  45. if err != nil {
  46. return nil, err
  47. }
  48. joinErrCount := 0
  49. for i := range ifaces {
  50. if err = conn.JoinGroup(&ifaces[i], &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}); err != nil {
  51. joinErrCount++
  52. }
  53. }
  54. if joinErrCount >= len(ifaces) {
  55. return nil, errJoiningMulticastGroup
  56. }
  57. dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress)
  58. if err != nil {
  59. return nil, err
  60. }
  61. loggerFactory := config.LoggerFactory
  62. if loggerFactory == nil {
  63. loggerFactory = logging.NewDefaultLoggerFactory()
  64. }
  65. localNames := []string{}
  66. for _, l := range config.LocalNames {
  67. localNames = append(localNames, l+".")
  68. }
  69. c := &Conn{
  70. queryInterval: defaultQueryInterval,
  71. queries: []query{},
  72. socket: conn,
  73. dstAddr: dstAddr,
  74. localNames: localNames,
  75. log: loggerFactory.NewLogger("mdns"),
  76. closed: make(chan interface{}),
  77. }
  78. if config.QueryInterval != 0 {
  79. c.queryInterval = config.QueryInterval
  80. }
  81. go c.start()
  82. return c, nil
  83. }
  84. // Close closes the mDNS Conn
  85. func (c *Conn) Close() error {
  86. select {
  87. case <-c.closed:
  88. return nil
  89. default:
  90. }
  91. if err := c.socket.Close(); err != nil {
  92. return err
  93. }
  94. <-c.closed
  95. return nil
  96. }
  97. // Query sends mDNS Queries for the following name until
  98. // either the Context is canceled/expires or we get a result
  99. func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {
  100. select {
  101. case <-c.closed:
  102. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  103. default:
  104. }
  105. nameWithSuffix := name + "."
  106. queryChan := make(chan queryResult, 1)
  107. c.mu.Lock()
  108. c.queries = append(c.queries, query{nameWithSuffix, queryChan})
  109. ticker := time.NewTicker(c.queryInterval)
  110. c.mu.Unlock()
  111. defer ticker.Stop()
  112. c.sendQuestion(nameWithSuffix)
  113. for {
  114. select {
  115. case <-ticker.C:
  116. c.sendQuestion(nameWithSuffix)
  117. case <-c.closed:
  118. return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
  119. case res := <-queryChan:
  120. return res.answer, res.addr, nil
  121. case <-ctx.Done():
  122. return dnsmessage.ResourceHeader{}, nil, errContextElapsed
  123. }
  124. }
  125. }
  126. func ipToBytes(ip net.IP) (out [4]byte) {
  127. rawIP := ip.To4()
  128. if rawIP == nil {
  129. return
  130. }
  131. ipInt := big.NewInt(0)
  132. ipInt.SetBytes(rawIP)
  133. copy(out[:], ipInt.Bytes())
  134. return
  135. }
  136. func interfaceForRemote(remote string) (net.IP, error) {
  137. conn, err := net.Dial("udp", remote)
  138. if err != nil {
  139. return nil, err
  140. }
  141. localAddr := conn.LocalAddr().(*net.UDPAddr)
  142. if err := conn.Close(); err != nil {
  143. return nil, err
  144. }
  145. return localAddr.IP, nil
  146. }
  147. func (c *Conn) sendQuestion(name string) {
  148. packedName, err := dnsmessage.NewName(name)
  149. if err != nil {
  150. c.log.Warnf("Failed to construct mDNS packet %v", err)
  151. return
  152. }
  153. msg := dnsmessage.Message{
  154. Header: dnsmessage.Header{},
  155. Questions: []dnsmessage.Question{
  156. {
  157. Type: dnsmessage.TypeA,
  158. Class: dnsmessage.ClassINET,
  159. Name: packedName,
  160. },
  161. },
  162. }
  163. rawQuery, err := msg.Pack()
  164. if err != nil {
  165. c.log.Warnf("Failed to construct mDNS packet %v", err)
  166. return
  167. }
  168. if _, err := c.socket.WriteTo(rawQuery, nil, c.dstAddr); err != nil {
  169. c.log.Warnf("Failed to send mDNS packet %v", err)
  170. return
  171. }
  172. }
  173. func (c *Conn) sendAnswer(name string, dst net.IP) {
  174. packedName, err := dnsmessage.NewName(name)
  175. if err != nil {
  176. c.log.Warnf("Failed to construct mDNS packet %v", err)
  177. return
  178. }
  179. msg := dnsmessage.Message{
  180. Header: dnsmessage.Header{
  181. Response: true,
  182. Authoritative: true,
  183. },
  184. Answers: []dnsmessage.Resource{
  185. {
  186. Header: dnsmessage.ResourceHeader{
  187. Type: dnsmessage.TypeA,
  188. Class: dnsmessage.ClassINET,
  189. Name: packedName,
  190. TTL: responseTTL,
  191. },
  192. Body: &dnsmessage.AResource{
  193. A: ipToBytes(dst),
  194. },
  195. },
  196. },
  197. }
  198. rawAnswer, err := msg.Pack()
  199. if err != nil {
  200. c.log.Warnf("Failed to construct mDNS packet %v", err)
  201. return
  202. }
  203. if _, err := c.socket.WriteTo(rawAnswer, nil, c.dstAddr); err != nil {
  204. c.log.Warnf("Failed to send mDNS packet %v", err)
  205. return
  206. }
  207. }
  208. func (c *Conn) start() { //nolint gocognit
  209. defer func() {
  210. c.mu.Lock()
  211. defer c.mu.Unlock()
  212. close(c.closed)
  213. }()
  214. b := make([]byte, inboundBufferSize)
  215. p := dnsmessage.Parser{}
  216. for {
  217. n, _, src, err := c.socket.ReadFrom(b)
  218. if err != nil {
  219. return
  220. }
  221. func() {
  222. c.mu.RLock()
  223. defer c.mu.RUnlock()
  224. if _, err := p.Start(b[:n]); err != nil {
  225. c.log.Warnf("Failed to parse mDNS packet %v", err)
  226. return
  227. }
  228. for i := 0; i <= maxMessageRecords; i++ {
  229. q, err := p.Question()
  230. if errors.Is(err, dnsmessage.ErrSectionDone) {
  231. break
  232. } else if err != nil {
  233. c.log.Warnf("Failed to parse mDNS packet %v", err)
  234. return
  235. }
  236. for _, localName := range c.localNames {
  237. if localName == q.Name.String() {
  238. localAddress, err := interfaceForRemote(src.String())
  239. if err != nil {
  240. c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
  241. continue
  242. }
  243. c.sendAnswer(q.Name.String(), localAddress)
  244. }
  245. }
  246. }
  247. for i := 0; i <= maxMessageRecords; i++ {
  248. a, err := p.AnswerHeader()
  249. if errors.Is(err, dnsmessage.ErrSectionDone) {
  250. return
  251. }
  252. if err != nil {
  253. c.log.Warnf("Failed to parse mDNS packet %v", err)
  254. return
  255. }
  256. if a.Type != dnsmessage.TypeA && a.Type != dnsmessage.TypeAAAA {
  257. continue
  258. }
  259. for i := len(c.queries) - 1; i >= 0; i-- {
  260. if c.queries[i].nameWithSuffix == a.Name.String() {
  261. c.queries[i].queryResultChan <- queryResult{a, src}
  262. c.queries = append(c.queries[:i], c.queries[i+1:]...)
  263. }
  264. }
  265. }
  266. }()
  267. }
  268. }