conn.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. // Package udp provides a connection-oriented listener over a UDP PacketConn
  2. package udp
  3. import (
  4. "context"
  5. "errors"
  6. "net"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "github.com/pion/transport/v2/deadline"
  11. "github.com/pion/transport/v2/packetio"
  12. pkgSync "github.com/pion/udp/pkg/sync"
  13. )
  14. const (
  15. receiveMTU = 8192
  16. defaultListenBacklog = 128 // same as Linux default
  17. )
  18. // Typed errors
  19. var (
  20. ErrClosedListener = errors.New("udp: listener closed")
  21. ErrListenQueueExceeded = errors.New("udp: listen queue exceeded")
  22. )
  23. // listener augments a connection-oriented Listener over a UDP PacketConn
  24. type listener struct {
  25. pConn *net.UDPConn
  26. accepting atomic.Value // bool
  27. acceptCh chan *Conn
  28. doneCh chan struct{}
  29. doneOnce sync.Once
  30. acceptFilter func([]byte) bool
  31. readBufferPool *sync.Pool
  32. connLock sync.Mutex
  33. conns map[string]*Conn
  34. connWG *pkgSync.WaitGroup
  35. readWG sync.WaitGroup
  36. errClose atomic.Value // error
  37. }
  38. // Accept waits for and returns the next connection to the listener.
  39. func (l *listener) Accept() (net.Conn, error) {
  40. select {
  41. case c := <-l.acceptCh:
  42. l.connWG.Add(1)
  43. return c, nil
  44. case <-l.doneCh:
  45. return nil, ErrClosedListener
  46. }
  47. }
  48. // Close closes the listener.
  49. // Any blocked Accept operations will be unblocked and return errors.
  50. func (l *listener) Close() error {
  51. var err error
  52. l.doneOnce.Do(func() {
  53. l.accepting.Store(false)
  54. close(l.doneCh)
  55. l.connLock.Lock()
  56. // Close unaccepted connections
  57. L_CLOSE:
  58. for {
  59. select {
  60. case c := <-l.acceptCh:
  61. close(c.doneCh)
  62. delete(l.conns, c.rAddr.String())
  63. default:
  64. break L_CLOSE
  65. }
  66. }
  67. nConns := len(l.conns)
  68. l.connLock.Unlock()
  69. l.connWG.Done()
  70. if nConns == 0 {
  71. // Wait if this is the final connection
  72. l.readWG.Wait()
  73. if errClose, ok := l.errClose.Load().(error); ok {
  74. err = errClose
  75. }
  76. } else {
  77. err = nil
  78. }
  79. })
  80. return err
  81. }
  82. // Addr returns the listener's network address.
  83. func (l *listener) Addr() net.Addr {
  84. return l.pConn.LocalAddr()
  85. }
  86. // ListenConfig stores options for listening to an address.
  87. type ListenConfig struct {
  88. // Backlog defines the maximum length of the queue of pending
  89. // connections. It is equivalent of the backlog argument of
  90. // POSIX listen function.
  91. // If a connection request arrives when the queue is full,
  92. // the request will be silently discarded, unlike TCP.
  93. // Set zero to use default value 128 which is same as Linux default.
  94. Backlog int
  95. // AcceptFilter determines whether the new conn should be made for
  96. // the incoming packet. If not set, any packet creates new conn.
  97. AcceptFilter func([]byte) bool
  98. }
  99. // Listen creates a new listener based on the ListenConfig.
  100. func (lc *ListenConfig) Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
  101. if lc.Backlog == 0 {
  102. lc.Backlog = defaultListenBacklog
  103. }
  104. conn, err := net.ListenUDP(network, laddr)
  105. if err != nil {
  106. return nil, err
  107. }
  108. l := &listener{
  109. pConn: conn,
  110. acceptCh: make(chan *Conn, lc.Backlog),
  111. conns: make(map[string]*Conn),
  112. doneCh: make(chan struct{}),
  113. acceptFilter: lc.AcceptFilter,
  114. readBufferPool: &sync.Pool{
  115. New: func() interface{} {
  116. buf := make([]byte, receiveMTU)
  117. return &buf
  118. },
  119. },
  120. connWG: pkgSync.NewWaitGroup(),
  121. }
  122. l.accepting.Store(true)
  123. l.connWG.Add(1)
  124. l.readWG.Add(2) // wait readLoop and Close execution routine
  125. go l.readLoop()
  126. go func() {
  127. l.connWG.Wait()
  128. if err := l.pConn.Close(); err != nil {
  129. l.errClose.Store(err)
  130. }
  131. l.readWG.Done()
  132. }()
  133. return l, nil
  134. }
  135. // Listen creates a new listener using default ListenConfig.
  136. func Listen(network string, laddr *net.UDPAddr) (net.Listener, error) {
  137. return (&ListenConfig{}).Listen(network, laddr)
  138. }
  139. // readLoop has to tasks:
  140. // 1. Dispatching incoming packets to the correct Conn.
  141. // It can therefore not be ended until all Conns are closed.
  142. // 2. Creating a new Conn when receiving from a new remote.
  143. func (l *listener) readLoop() {
  144. defer l.readWG.Done()
  145. for {
  146. buf, ok := l.readBufferPool.Get().(*[]byte)
  147. if !ok {
  148. return
  149. }
  150. n, raddr, err := l.pConn.ReadFrom(*buf)
  151. if err != nil {
  152. return
  153. }
  154. conn, ok, err := l.getConn(raddr, (*buf)[:n])
  155. if err != nil {
  156. continue
  157. }
  158. if ok {
  159. _, _ = conn.buffer.Write((*buf)[:n])
  160. }
  161. }
  162. }
  163. func (l *listener) getConn(raddr net.Addr, buf []byte) (*Conn, bool, error) {
  164. l.connLock.Lock()
  165. defer l.connLock.Unlock()
  166. conn, ok := l.conns[raddr.String()]
  167. if !ok {
  168. if isAccepting, ok := l.accepting.Load().(bool); !isAccepting || !ok {
  169. return nil, false, ErrClosedListener
  170. }
  171. if l.acceptFilter != nil {
  172. if !l.acceptFilter(buf) {
  173. return nil, false, nil
  174. }
  175. }
  176. conn = l.newConn(raddr)
  177. select {
  178. case l.acceptCh <- conn:
  179. l.conns[raddr.String()] = conn
  180. default:
  181. return nil, false, ErrListenQueueExceeded
  182. }
  183. }
  184. return conn, true, nil
  185. }
  186. // Conn augments a connection-oriented connection over a UDP PacketConn
  187. type Conn struct {
  188. listener *listener
  189. rAddr net.Addr
  190. buffer *packetio.Buffer
  191. doneCh chan struct{}
  192. doneOnce sync.Once
  193. writeDeadline *deadline.Deadline
  194. }
  195. func (l *listener) newConn(rAddr net.Addr) *Conn {
  196. return &Conn{
  197. listener: l,
  198. rAddr: rAddr,
  199. buffer: packetio.NewBuffer(),
  200. doneCh: make(chan struct{}),
  201. writeDeadline: deadline.New(),
  202. }
  203. }
  204. // Read reads from c into p
  205. func (c *Conn) Read(p []byte) (int, error) {
  206. return c.buffer.Read(p)
  207. }
  208. // Write writes len(p) bytes from p to the DTLS connection
  209. func (c *Conn) Write(p []byte) (n int, err error) {
  210. select {
  211. case <-c.writeDeadline.Done():
  212. return 0, context.DeadlineExceeded
  213. default:
  214. }
  215. return c.listener.pConn.WriteTo(p, c.rAddr)
  216. }
  217. // Close closes the conn and releases any Read calls
  218. func (c *Conn) Close() error {
  219. var err error
  220. c.doneOnce.Do(func() {
  221. c.listener.connWG.Done()
  222. close(c.doneCh)
  223. c.listener.connLock.Lock()
  224. delete(c.listener.conns, c.rAddr.String())
  225. nConns := len(c.listener.conns)
  226. c.listener.connLock.Unlock()
  227. if isAccepting, ok := c.listener.accepting.Load().(bool); nConns == 0 && !isAccepting && ok {
  228. // Wait if this is the final connection
  229. c.listener.readWG.Wait()
  230. if errClose, ok := c.listener.errClose.Load().(error); ok {
  231. err = errClose
  232. }
  233. } else {
  234. err = nil
  235. }
  236. if errBuf := c.buffer.Close(); errBuf != nil && err == nil {
  237. err = errBuf
  238. }
  239. })
  240. return err
  241. }
  242. // LocalAddr implements net.Conn.LocalAddr
  243. func (c *Conn) LocalAddr() net.Addr {
  244. return c.listener.pConn.LocalAddr()
  245. }
  246. // RemoteAddr implements net.Conn.RemoteAddr
  247. func (c *Conn) RemoteAddr() net.Addr {
  248. return c.rAddr
  249. }
  250. // SetDeadline implements net.Conn.SetDeadline
  251. func (c *Conn) SetDeadline(t time.Time) error {
  252. c.writeDeadline.Set(t)
  253. return c.SetReadDeadline(t)
  254. }
  255. // SetReadDeadline implements net.Conn.SetDeadline
  256. func (c *Conn) SetReadDeadline(t time.Time) error {
  257. return c.buffer.SetReadDeadline(t)
  258. }
  259. // SetWriteDeadline implements net.Conn.SetDeadline
  260. func (c *Conn) SetWriteDeadline(t time.Time) error {
  261. c.writeDeadline.Set(t)
  262. // Write deadline of underlying connection should not be changed
  263. // since the connection can be shared.
  264. return nil
  265. }