tcp_packet_conn.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package ice
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/pion/logging"
  9. )
  10. type tcpPacketConn struct {
  11. params *tcpPacketParams
  12. // conns is a map of net.Conns indexed by remote net.Addr.String()
  13. conns map[string]net.Conn
  14. recvChan chan streamingPacket
  15. mu sync.Mutex
  16. wg sync.WaitGroup
  17. closedChan chan struct{}
  18. closeOnce sync.Once
  19. }
  // streamingPacket is one item on the receive queue: either a packet's
  // payload or a read error, tagged with the remote address it came from.
  // Field order matters: values are built with positional composite literals.
  type streamingPacket struct {
  	Data  []byte   // payload copied out of the read buffer; nil when Err is set
  	RAddr net.Addr // remote address of the connection that produced this item
  	Err   error    // non-nil when the connection's read failed
  }
  25. type tcpPacketParams struct {
  26. ReadBuffer int
  27. LocalAddr net.Addr
  28. Logger logging.LeveledLogger
  29. }
  30. func newTCPPacketConn(params tcpPacketParams) *tcpPacketConn {
  31. p := &tcpPacketConn{
  32. params: &params,
  33. conns: map[string]net.Conn{},
  34. recvChan: make(chan streamingPacket, params.ReadBuffer),
  35. closedChan: make(chan struct{}),
  36. }
  37. return p
  38. }
  39. func (t *tcpPacketConn) AddConn(conn net.Conn, firstPacketData []byte) error {
  40. t.params.Logger.Infof("AddConn: %s %s", conn.RemoteAddr().Network(), conn.RemoteAddr())
  41. t.mu.Lock()
  42. defer t.mu.Unlock()
  43. select {
  44. case <-t.closedChan:
  45. return io.ErrClosedPipe
  46. default:
  47. }
  48. if _, ok := t.conns[conn.RemoteAddr().String()]; ok {
  49. return fmt.Errorf("%w: %s", errConnectionAddrAlreadyExist, conn.RemoteAddr().String())
  50. }
  51. t.conns[conn.RemoteAddr().String()] = conn
  52. t.wg.Add(1)
  53. go func() {
  54. if firstPacketData != nil {
  55. t.recvChan <- streamingPacket{firstPacketData, conn.RemoteAddr(), nil}
  56. }
  57. defer t.wg.Done()
  58. t.startReading(conn)
  59. }()
  60. return nil
  61. }
  62. func (t *tcpPacketConn) startReading(conn net.Conn) {
  63. buf := make([]byte, receiveMTU)
  64. for {
  65. n, err := readStreamingPacket(conn, buf)
  66. // t.params.Logger.Infof("readStreamingPacket read %d bytes", n)
  67. if err != nil {
  68. t.params.Logger.Infof("%w: %s\n", errReadingStreamingPacket, err)
  69. t.handleRecv(streamingPacket{nil, conn.RemoteAddr(), err})
  70. t.removeConn(conn)
  71. return
  72. }
  73. data := make([]byte, n)
  74. copy(data, buf[:n])
  75. // t.params.Logger.Infof("Writing read streaming packet to recvChan: %d bytes", len(data))
  76. t.handleRecv(streamingPacket{data, conn.RemoteAddr(), nil})
  77. }
  78. }
  79. func (t *tcpPacketConn) handleRecv(pkt streamingPacket) {
  80. t.mu.Lock()
  81. recvChan := t.recvChan
  82. if t.isClosed() {
  83. recvChan = nil
  84. }
  85. t.mu.Unlock()
  86. select {
  87. case recvChan <- pkt:
  88. case <-t.closedChan:
  89. }
  90. }
  91. func (t *tcpPacketConn) isClosed() bool {
  92. select {
  93. case <-t.closedChan:
  94. return true
  95. default:
  96. return false
  97. }
  98. }
  99. // WriteTo is for passive and s-o candidates.
  100. func (t *tcpPacketConn) ReadFrom(b []byte) (n int, raddr net.Addr, err error) {
  101. pkt, ok := <-t.recvChan
  102. if !ok {
  103. return 0, nil, io.ErrClosedPipe
  104. }
  105. if pkt.Err != nil {
  106. return 0, pkt.RAddr, pkt.Err
  107. }
  108. if cap(b) < len(pkt.Data) {
  109. return 0, pkt.RAddr, io.ErrShortBuffer
  110. }
  111. n = len(pkt.Data)
  112. copy(b, pkt.Data[:n])
  113. return n, pkt.RAddr, err
  114. }
  115. // WriteTo is for active and s-o candidates.
  116. func (t *tcpPacketConn) WriteTo(buf []byte, raddr net.Addr) (n int, err error) {
  117. t.mu.Lock()
  118. conn, ok := t.conns[raddr.String()]
  119. t.mu.Unlock()
  120. if !ok {
  121. return 0, io.ErrClosedPipe
  122. // conn, err := net.DialTCP(tcp, nil, raddr.(*net.TCPAddr))
  123. // if err != nil {
  124. // t.params.Logger.Tracef("DialTCP error: %s", err)
  125. // return 0, err
  126. // }
  127. // go t.startReading(conn)
  128. // t.conns[raddr.String()] = conn
  129. }
  130. n, err = writeStreamingPacket(conn, buf)
  131. if err != nil {
  132. t.params.Logger.Tracef("%w %s\n", errWriting, raddr)
  133. return n, err
  134. }
  135. return n, err
  136. }
  137. func (t *tcpPacketConn) closeAndLogError(closer io.Closer) {
  138. err := closer.Close()
  139. if err != nil {
  140. t.params.Logger.Warnf("%w: %s", errClosingConnection, err)
  141. }
  142. }
  143. func (t *tcpPacketConn) removeConn(conn net.Conn) {
  144. t.mu.Lock()
  145. defer t.mu.Unlock()
  146. t.closeAndLogError(conn)
  147. delete(t.conns, conn.RemoteAddr().String())
  148. }
  149. func (t *tcpPacketConn) Close() error {
  150. t.mu.Lock()
  151. var shouldCloseRecvChan bool
  152. t.closeOnce.Do(func() {
  153. close(t.closedChan)
  154. shouldCloseRecvChan = true
  155. })
  156. for _, conn := range t.conns {
  157. t.closeAndLogError(conn)
  158. delete(t.conns, conn.RemoteAddr().String())
  159. }
  160. t.mu.Unlock()
  161. t.wg.Wait()
  162. if shouldCloseRecvChan {
  163. close(t.recvChan)
  164. }
  165. return nil
  166. }
  167. func (t *tcpPacketConn) LocalAddr() net.Addr {
  168. return t.params.LocalAddr
  169. }
  170. func (t *tcpPacketConn) SetDeadline(tm time.Time) error {
  171. return nil
  172. }
  173. func (t *tcpPacketConn) SetReadDeadline(tm time.Time) error {
  174. return nil
  175. }
  176. func (t *tcpPacketConn) SetWriteDeadline(tm time.Time) error {
  177. return nil
  178. }
  179. func (t *tcpPacketConn) CloseChannel() <-chan struct{} {
  180. return t.closedChan
  181. }
  182. func (t *tcpPacketConn) String() string {
  183. return fmt.Sprintf("tcpPacketConn{LocalAddr: %s}", t.params.LocalAddr)
  184. }