udp_muxed_conn.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. package ice
  2. import (
  3. "encoding/binary"
  4. "io"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/pion/logging"
  9. "github.com/pion/transport/packetio"
  10. )
  11. type udpMuxedConnParams struct {
  12. Mux *UDPMuxDefault
  13. AddrPool *sync.Pool
  14. Key string
  15. LocalAddr net.Addr
  16. Logger logging.LeveledLogger
  17. }
  18. // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
  19. type udpMuxedConn struct {
  20. params *udpMuxedConnParams
  21. // remote addresses that we have sent to on this conn
  22. addresses []string
  23. // channel holding incoming packets
  24. buffer *packetio.Buffer
  25. closedChan chan struct{}
  26. closeOnce sync.Once
  27. mu sync.Mutex
  28. }
  29. func newUDPMuxedConn(params *udpMuxedConnParams) *udpMuxedConn {
  30. p := &udpMuxedConn{
  31. params: params,
  32. buffer: packetio.NewBuffer(),
  33. closedChan: make(chan struct{}),
  34. }
  35. return p
  36. }
  37. func (c *udpMuxedConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
  38. buf := c.params.AddrPool.Get().(*bufferHolder)
  39. defer c.params.AddrPool.Put(buf)
  40. // read address
  41. total, err := c.buffer.Read(buf.buffer)
  42. if err != nil {
  43. return 0, nil, err
  44. }
  45. dataLen := int(binary.LittleEndian.Uint16(buf.buffer[:2]))
  46. if dataLen > total || dataLen > len(b) {
  47. return 0, nil, io.ErrShortBuffer
  48. }
  49. // read data and then address
  50. offset := 2
  51. copy(b, buf.buffer[offset:offset+dataLen])
  52. offset += dataLen
  53. // read address len & decode address
  54. addrLen := int(binary.LittleEndian.Uint16(buf.buffer[offset : offset+2]))
  55. offset += 2
  56. if raddr, err = decodeUDPAddr(buf.buffer[offset : offset+addrLen]); err != nil {
  57. return 0, nil, err
  58. }
  59. return dataLen, raddr, nil
  60. }
  61. func (c *udpMuxedConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
  62. if c.isClosed() {
  63. return 0, io.ErrClosedPipe
  64. }
  65. // each time we write to a new address, we'll register it with the mux
  66. addr := raddr.String()
  67. if !c.containsAddress(addr) {
  68. c.addAddress(addr)
  69. }
  70. return c.params.Mux.writeTo(buf, raddr)
  71. }
  72. func (c *udpMuxedConn) LocalAddr() net.Addr {
  73. return c.params.LocalAddr
  74. }
  75. func (c *udpMuxedConn) SetDeadline(tm time.Time) error {
  76. return nil
  77. }
  78. func (c *udpMuxedConn) SetReadDeadline(tm time.Time) error {
  79. return nil
  80. }
  81. func (c *udpMuxedConn) SetWriteDeadline(tm time.Time) error {
  82. return nil
  83. }
  84. func (c *udpMuxedConn) CloseChannel() <-chan struct{} {
  85. return c.closedChan
  86. }
  87. func (c *udpMuxedConn) Close() error {
  88. var err error
  89. c.closeOnce.Do(func() {
  90. err = c.buffer.Close()
  91. close(c.closedChan)
  92. })
  93. c.mu.Lock()
  94. defer c.mu.Unlock()
  95. c.addresses = nil
  96. return err
  97. }
  98. func (c *udpMuxedConn) isClosed() bool {
  99. select {
  100. case <-c.closedChan:
  101. return true
  102. default:
  103. return false
  104. }
  105. }
  106. func (c *udpMuxedConn) getAddresses() []string {
  107. c.mu.Lock()
  108. defer c.mu.Unlock()
  109. addresses := make([]string, len(c.addresses))
  110. copy(addresses, c.addresses)
  111. return addresses
  112. }
  113. func (c *udpMuxedConn) addAddress(addr string) {
  114. c.mu.Lock()
  115. c.addresses = append(c.addresses, addr)
  116. c.mu.Unlock()
  117. // map it on mux
  118. c.params.Mux.registerConnForAddress(c, addr)
  119. }
  120. func (c *udpMuxedConn) removeAddress(addr string) {
  121. c.mu.Lock()
  122. defer c.mu.Unlock()
  123. newAddresses := make([]string, 0, len(c.addresses))
  124. for _, a := range c.addresses {
  125. if a != addr {
  126. newAddresses = append(newAddresses, a)
  127. }
  128. }
  129. c.addresses = newAddresses
  130. }
  131. func (c *udpMuxedConn) containsAddress(addr string) bool {
  132. c.mu.Lock()
  133. defer c.mu.Unlock()
  134. for _, a := range c.addresses {
  135. if addr == a {
  136. return true
  137. }
  138. }
  139. return false
  140. }
  141. func (c *udpMuxedConn) writePacket(data []byte, addr *net.UDPAddr) error {
  142. // write two packets, address and data
  143. buf := c.params.AddrPool.Get().(*bufferHolder)
  144. defer c.params.AddrPool.Put(buf)
  145. // format of buffer | data len | data bytes | addr len | addr bytes |
  146. if len(buf.buffer) < len(data)+maxAddrSize {
  147. return io.ErrShortBuffer
  148. }
  149. // data len
  150. binary.LittleEndian.PutUint16(buf.buffer, uint16(len(data)))
  151. offset := 2
  152. // data
  153. copy(buf.buffer[offset:], data)
  154. offset += len(data)
  155. // write address first, leaving room for its length
  156. n, err := encodeUDPAddr(addr, buf.buffer[offset+2:])
  157. if err != nil {
  158. return nil
  159. }
  160. total := offset + n + 2
  161. // address len
  162. binary.LittleEndian.PutUint16(buf.buffer[offset:], uint16(n))
  163. if _, err := c.buffer.Write(buf.buffer[:total]); err != nil {
  164. return err
  165. }
  166. return nil
  167. }
  168. func encodeUDPAddr(addr *net.UDPAddr, buf []byte) (int, error) {
  169. ipdata, err := addr.IP.MarshalText()
  170. if err != nil {
  171. return 0, err
  172. }
  173. total := 2 + len(ipdata) + 2 + len(addr.Zone)
  174. if total > len(buf) {
  175. return 0, io.ErrShortBuffer
  176. }
  177. binary.LittleEndian.PutUint16(buf, uint16(len(ipdata)))
  178. offset := 2
  179. n := copy(buf[offset:], ipdata)
  180. offset += n
  181. binary.LittleEndian.PutUint16(buf[offset:], uint16(addr.Port))
  182. offset += 2
  183. copy(buf[offset:], addr.Zone)
  184. return total, nil
  185. }
  186. func decodeUDPAddr(buf []byte) (*net.UDPAddr, error) {
  187. addr := net.UDPAddr{}
  188. offset := 0
  189. ipLen := int(binary.LittleEndian.Uint16(buf[:2]))
  190. offset += 2
  191. // basic bounds checking
  192. if ipLen+offset > len(buf) {
  193. return nil, io.ErrShortBuffer
  194. }
  195. if err := addr.IP.UnmarshalText(buf[offset : offset+ipLen]); err != nil {
  196. return nil, err
  197. }
  198. offset += ipLen
  199. addr.Port = int(binary.LittleEndian.Uint16(buf[offset : offset+2]))
  200. offset += 2
  201. zone := make([]byte, len(buf[offset:]))
  202. copy(zone, buf[offset:])
  203. addr.Zone = string(zone)
  204. return &addr, nil
  205. }