tcp_mux.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. package ice
  import (
	"encoding/binary"
	"errors"
	"io"
	"net"
	"strings"
	"sync"

	"github.com/pion/logging"
	"github.com/pion/stun"
)
  // TCPMux allows grouping multiple TCP net.Conns and using them like UDP
// net.PacketConns. The main implementation of this is TCPMuxDefault, and this
// interface exists to:
//  1. prevent SEGV panics when TCPMuxDefault is not initialized by using the
//     invalidTCPMux implementation, and
//  2. allow mocking in tests.
type TCPMux interface {
	io.Closer

	// GetConnByUfrag returns the net.PacketConn associated with ufrag for the
	// requested address family. Implementations may create it on demand.
	GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error)

	// RemoveConnByUfrag releases whatever the implementation holds for ufrag.
	RemoveConnByUfrag(ufrag string)
}
  22. // invalidTCPMux is an implementation of TCPMux that always returns ErrTCPMuxNotInitialized.
  23. type invalidTCPMux struct{}
  24. func newInvalidTCPMux() *invalidTCPMux {
  25. return &invalidTCPMux{}
  26. }
  27. // Close implements TCPMux interface.
  28. func (m *invalidTCPMux) Close() error {
  29. return ErrTCPMuxNotInitialized
  30. }
  31. // GetConnByUfrag implements TCPMux interface.
  32. func (m *invalidTCPMux) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
  33. return nil, ErrTCPMuxNotInitialized
  34. }
  35. // RemoveConnByUfrag implements TCPMux interface.
  36. func (m *invalidTCPMux) RemoveConnByUfrag(ufrag string) {}
  37. // TCPMuxDefault muxes TCP net.Conns into net.PacketConns and groups them by
  38. // Ufrag. It is a default implementation of TCPMux interface.
  39. type TCPMuxDefault struct {
  40. params *TCPMuxParams
  41. closed bool
  42. // connsIPv4 and connsIPv6 are maps of all tcpPacketConns indexed by ufrag
  43. connsIPv4, connsIPv6 map[string]*tcpPacketConn
  44. mu sync.Mutex
  45. wg sync.WaitGroup
  46. }
  47. // TCPMuxParams are parameters for TCPMux.
  48. type TCPMuxParams struct {
  49. Listener net.Listener
  50. Logger logging.LeveledLogger
  51. ReadBufferSize int
  52. }
  53. // NewTCPMuxDefault creates a new instance of TCPMuxDefault.
  54. func NewTCPMuxDefault(params TCPMuxParams) *TCPMuxDefault {
  55. if params.Logger == nil {
  56. params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
  57. }
  58. m := &TCPMuxDefault{
  59. params: &params,
  60. connsIPv4: map[string]*tcpPacketConn{},
  61. connsIPv6: map[string]*tcpPacketConn{},
  62. }
  63. m.wg.Add(1)
  64. go func() {
  65. defer m.wg.Done()
  66. m.start()
  67. }()
  68. return m
  69. }
  70. func (m *TCPMuxDefault) start() {
  71. m.params.Logger.Infof("Listening TCP on %s\n", m.params.Listener.Addr())
  72. for {
  73. conn, err := m.params.Listener.Accept()
  74. if err != nil {
  75. m.params.Logger.Infof("Error accepting connection: %s\n", err)
  76. return
  77. }
  78. m.params.Logger.Debugf("Accepted connection from: %s to %s", conn.RemoteAddr(), conn.LocalAddr())
  79. m.wg.Add(1)
  80. go func() {
  81. defer m.wg.Done()
  82. m.handleConn(conn)
  83. }()
  84. }
  85. }
  86. // LocalAddr returns the listening address of this TCPMuxDefault.
  87. func (m *TCPMuxDefault) LocalAddr() net.Addr {
  88. return m.params.Listener.Addr()
  89. }
  90. // GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
  91. func (m *TCPMuxDefault) GetConnByUfrag(ufrag string, isIPv6 bool) (net.PacketConn, error) {
  92. m.mu.Lock()
  93. defer m.mu.Unlock()
  94. if m.closed {
  95. return nil, io.ErrClosedPipe
  96. }
  97. if conn, ok := m.getConn(ufrag, isIPv6); ok {
  98. return conn, nil
  99. }
  100. return m.createConn(ufrag, m.LocalAddr(), isIPv6), nil
  101. }
  102. func (m *TCPMuxDefault) createConn(ufrag string, localAddr net.Addr, isIPv6 bool) *tcpPacketConn {
  103. conn := newTCPPacketConn(tcpPacketParams{
  104. ReadBuffer: m.params.ReadBufferSize,
  105. LocalAddr: localAddr,
  106. Logger: m.params.Logger,
  107. })
  108. if isIPv6 {
  109. m.connsIPv6[ufrag] = conn
  110. } else {
  111. m.connsIPv4[ufrag] = conn
  112. }
  113. m.wg.Add(1)
  114. go func() {
  115. defer m.wg.Done()
  116. <-conn.CloseChannel()
  117. m.RemoveConnByUfrag(ufrag)
  118. }()
  119. return conn
  120. }
  121. func (m *TCPMuxDefault) closeAndLogError(closer io.Closer) {
  122. err := closer.Close()
  123. if err != nil {
  124. m.params.Logger.Warnf("Error closing connection: %s", err)
  125. }
  126. }
  127. func (m *TCPMuxDefault) handleConn(conn net.Conn) {
  128. buf := make([]byte, receiveMTU)
  129. n, err := readStreamingPacket(conn, buf)
  130. if err != nil {
  131. m.params.Logger.Warnf("Error reading first packet from %s: %s", conn.RemoteAddr().String(), err)
  132. return
  133. }
  134. buf = buf[:n]
  135. msg := &stun.Message{
  136. Raw: make([]byte, len(buf)),
  137. }
  138. // Explicitly copy raw buffer so Message can own the memory.
  139. copy(msg.Raw, buf)
  140. if err = msg.Decode(); err != nil {
  141. m.closeAndLogError(conn)
  142. m.params.Logger.Warnf("Failed to handle decode ICE from %s to %s: %v\n", conn.RemoteAddr(), conn.LocalAddr(), err)
  143. return
  144. }
  145. if m == nil || msg.Type.Method != stun.MethodBinding { // not a stun
  146. m.closeAndLogError(conn)
  147. m.params.Logger.Warnf("Not a STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
  148. return
  149. }
  150. for _, attr := range msg.Attributes {
  151. m.params.Logger.Debugf("msg attr: %s\n", attr.String())
  152. }
  153. attr, err := msg.Get(stun.AttrUsername)
  154. if err != nil {
  155. m.closeAndLogError(conn)
  156. m.params.Logger.Warnf("No Username attribute in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
  157. return
  158. }
  159. ufrag := strings.Split(string(attr), ":")[0]
  160. m.params.Logger.Debugf("Ufrag: %s\n", ufrag)
  161. m.mu.Lock()
  162. defer m.mu.Unlock()
  163. host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
  164. if err != nil {
  165. m.closeAndLogError(conn)
  166. m.params.Logger.Warnf("Failed to get host in STUN message from %s to %s\n", conn.RemoteAddr(), conn.LocalAddr())
  167. return
  168. }
  169. isIPv6 := net.ParseIP(host).To4() == nil
  170. packetConn, ok := m.getConn(ufrag, isIPv6)
  171. if !ok {
  172. packetConn = m.createConn(ufrag, conn.LocalAddr(), isIPv6)
  173. }
  174. if err := packetConn.AddConn(conn, buf); err != nil {
  175. m.closeAndLogError(conn)
  176. m.params.Logger.Warnf("Error adding conn to tcpPacketConn from %s to %s: %s\n", conn.RemoteAddr(), conn.LocalAddr(), err)
  177. return
  178. }
  179. }
  180. // Close closes the listener and waits for all goroutines to exit.
  181. func (m *TCPMuxDefault) Close() error {
  182. m.mu.Lock()
  183. m.closed = true
  184. for _, conn := range m.connsIPv4 {
  185. m.closeAndLogError(conn)
  186. }
  187. for _, conn := range m.connsIPv6 {
  188. m.closeAndLogError(conn)
  189. }
  190. m.connsIPv4 = map[string]*tcpPacketConn{}
  191. m.connsIPv6 = map[string]*tcpPacketConn{}
  192. err := m.params.Listener.Close()
  193. m.mu.Unlock()
  194. m.wg.Wait()
  195. return err
  196. }
  197. // RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag.
  198. func (m *TCPMuxDefault) RemoveConnByUfrag(ufrag string) {
  199. m.mu.Lock()
  200. defer m.mu.Unlock()
  201. if conn, ok := m.connsIPv4[ufrag]; ok {
  202. m.closeAndLogError(conn)
  203. delete(m.connsIPv4, ufrag)
  204. }
  205. if conn, ok := m.connsIPv6[ufrag]; ok {
  206. m.closeAndLogError(conn)
  207. delete(m.connsIPv6, ufrag)
  208. }
  209. }
  210. func (m *TCPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *tcpPacketConn, ok bool) {
  211. if isIPv6 {
  212. val, ok = m.connsIPv6[ufrag]
  213. } else {
  214. val, ok = m.connsIPv4[ufrag]
  215. }
  216. return
  217. }
  // streamingPacketHeaderLen is the size of the RFC 4571 length prefix.
const streamingPacketHeaderLen = 2

// maxStreamingPacketLen is the largest payload the 16-bit RFC 4571 length
// prefix can describe.
const maxStreamingPacketLen = 1<<16 - 1

// errStreamingPacketTooLarge is returned by writeStreamingPacket when the
// payload cannot be described by the 16-bit length prefix.
var errStreamingPacketTooLarge = errors.New("packet too large for RFC 4571 framing")

// readStreamingPacket reads one packet from a stream framed per
// https://tools.ietf.org/html/rfc4571#section-2, where a 2-byte big-endian
// length header prepends each packet:
//
//	 0                   1                   2                   3
//	 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//	-----------------------------------------------------------------
//	|             LENGTH            |  RTP or RTCP packet ...       |
//	-----------------------------------------------------------------
//
// On success it stores the payload at the start of buf and returns its length.
// If the announced length exceeds cap(buf), it returns that length together
// with io.ErrShortBuffer and leaves the payload unread in the stream. Any read
// error is returned with a length of 0.
func readStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	header := make([]byte, streamingPacketHeaderLen)
	if err := fillFromConn(conn, header); err != nil {
		return 0, err
	}

	length := int(binary.BigEndian.Uint16(header))
	if length > cap(buf) {
		return length, io.ErrShortBuffer
	}

	// Slicing up to length is valid because length <= cap(buf).
	if err := fillFromConn(conn, buf[:length]); err != nil {
		return 0, err
	}
	return length, nil
}

// fillFromConn reads exactly len(p) bytes from conn.
//
// Following the io.Reader contract, bytes returned together with an error
// still count. A frame whose last bytes arrive alongside io.EOF is therefore
// delivered rather than dropped. A short read returns the underlying error
// unchanged (e.g. io.EOF rather than io.ErrUnexpectedEOF), which keeps the
// error values callers of readStreamingPacket see as they were.
func fillFromConn(conn net.Conn, p []byte) error {
	for read := 0; read < len(p); {
		n, err := conn.Read(p[read:])
		read += n
		if read >= len(p) {
			return nil
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// writeStreamingPacket writes buf to conn as one RFC 4571 frame (2-byte
// big-endian length, then the payload) in a single Write call. It returns the
// number of payload bytes written.
//
// Payloads longer than maxStreamingPacketLen are rejected before anything is
// written. Truncating the length prefix would emit a corrupt frame and
// desynchronize the peer's framing.
func writeStreamingPacket(conn net.Conn, buf []byte) (int, error) {
	if len(buf) > maxStreamingPacketLen {
		return 0, errStreamingPacketTooLarge
	}

	frame := make([]byte, streamingPacketHeaderLen+len(buf))
	binary.BigEndian.PutUint16(frame, uint16(len(buf)))
	copy(frame[streamingPacketHeaderLen:], buf)

	n, err := conn.Write(frame)
	if err != nil {
		return 0, err
	}
	return n - streamingPacketHeaderLen, nil
}