server.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. // Package turn contains the public API for pion/turn, a toolkit for building TURN clients and servers
  2. package turn
  3. import (
  4. "fmt"
  5. "net"
  6. "sync"
  7. "time"
  8. "github.com/pion/logging"
  9. "github.com/pion/turn/v2/internal/allocation"
  10. "github.com/pion/turn/v2/internal/proto"
  11. "github.com/pion/turn/v2/internal/server"
  12. )
  13. const (
  14. defaultInboundMTU = 1600
  15. )
// Server is an instance of the Pion TURN Server.
//
// A Server is built by NewServer and stopped with Close. Each configured
// PacketConn and Listener is served with its own allocation.Manager.
type Server struct {
	// log receives all diagnostics emitted by the server.
	log logging.LeveledLogger
	// authHandler is passed to server.HandleRequest for every inbound message.
	authHandler AuthHandler
	// realm is passed to server.HandleRequest for every inbound message.
	realm string
	// channelBindTimeout is passed to server.HandleRequest; NewServer replaces
	// a zero value with proto.DefaultLifetime.
	channelBindTimeout time.Duration
	// nonces is one map shared by every read loop of this server and passed to
	// server.HandleRequest. NOTE(review): presumably tracks issued
	// authentication nonces — confirm in internal/server.
	nonces *sync.Map
	// packetConnConfigs are the packet sockets served by this server; Close
	// closes each PacketConn.
	packetConnConfigs []PacketConnConfig
	// listenerConfigs are the stream listeners served by this server; Close
	// closes each Listener.
	listenerConfigs []ListenerConfig
	// allocationManagers holds one manager per PacketConnConfig (indices
	// 0..len(packetConnConfigs)-1) followed by one per ListenerConfig.
	allocationManagers []*allocation.Manager
	// inboundMTU is the size of the read buffer allocated by each read loop.
	inboundMTU int
}
  28. // NewServer creates the Pion TURN server
  29. func NewServer(config ServerConfig) (*Server, error) {
  30. if err := config.validate(); err != nil {
  31. return nil, err
  32. }
  33. loggerFactory := config.LoggerFactory
  34. if loggerFactory == nil {
  35. loggerFactory = logging.NewDefaultLoggerFactory()
  36. }
  37. mtu := defaultInboundMTU
  38. if config.InboundMTU != 0 {
  39. mtu = config.InboundMTU
  40. }
  41. s := &Server{
  42. log: loggerFactory.NewLogger("turn"),
  43. authHandler: config.AuthHandler,
  44. realm: config.Realm,
  45. channelBindTimeout: config.ChannelBindTimeout,
  46. packetConnConfigs: config.PacketConnConfigs,
  47. listenerConfigs: config.ListenerConfigs,
  48. allocationManagers: make([]*allocation.Manager, len(config.PacketConnConfigs)+len(config.ListenerConfigs)),
  49. nonces: &sync.Map{},
  50. inboundMTU: mtu,
  51. }
  52. if s.channelBindTimeout == 0 {
  53. s.channelBindTimeout = proto.DefaultLifetime
  54. }
  55. for i := range s.packetConnConfigs {
  56. go func(i int, p PacketConnConfig) {
  57. allocationManager, err := allocation.NewManager(allocation.ManagerConfig{
  58. AllocatePacketConn: p.RelayAddressGenerator.AllocatePacketConn,
  59. AllocateConn: p.RelayAddressGenerator.AllocateConn,
  60. LeveledLogger: s.log,
  61. })
  62. if err != nil {
  63. s.log.Errorf("exit read loop on error: %s", err.Error())
  64. return
  65. }
  66. s.allocationManagers[i] = allocationManager
  67. defer func() {
  68. if err := allocationManager.Close(); err != nil {
  69. s.log.Errorf("Failed to close AllocationManager: %s", err.Error())
  70. }
  71. }()
  72. s.readLoop(p.PacketConn, allocationManager)
  73. }(i, s.packetConnConfigs[i])
  74. }
  75. for i, listener := range s.listenerConfigs {
  76. go func(i int, l ListenerConfig) {
  77. allocationManager, err := allocation.NewManager(allocation.ManagerConfig{
  78. AllocatePacketConn: l.RelayAddressGenerator.AllocatePacketConn,
  79. AllocateConn: l.RelayAddressGenerator.AllocateConn,
  80. LeveledLogger: s.log,
  81. })
  82. if err != nil {
  83. s.log.Errorf("exit read loop on error: %s", err.Error())
  84. return
  85. }
  86. s.allocationManagers[i] = allocationManager
  87. defer func() {
  88. if err := allocationManager.Close(); err != nil {
  89. s.log.Errorf("Failed to close AllocationManager: %s", err.Error())
  90. }
  91. }()
  92. for {
  93. conn, err := l.Listener.Accept()
  94. if err != nil {
  95. s.log.Debugf("exit accept loop on error: %s", err.Error())
  96. return
  97. }
  98. go s.readLoop(NewSTUNConn(conn), allocationManager)
  99. }
  100. }(i+len(s.packetConnConfigs), listener)
  101. }
  102. return s, nil
  103. }
  104. // AllocationCount returns the number of active allocations. It can be used to drain the server before closing
  105. func (s *Server) AllocationCount() int {
  106. allocations := 0
  107. for _, manager := range s.allocationManagers {
  108. if manager != nil {
  109. allocations += manager.AllocationCount()
  110. }
  111. }
  112. return allocations
  113. }
  114. // Close stops the TURN Server. It cleans up any associated state and closes all connections it is managing
  115. func (s *Server) Close() error {
  116. var errors []error
  117. for _, p := range s.packetConnConfigs {
  118. if err := p.PacketConn.Close(); err != nil {
  119. errors = append(errors, err)
  120. }
  121. }
  122. for _, l := range s.listenerConfigs {
  123. if err := l.Listener.Close(); err != nil {
  124. errors = append(errors, err)
  125. }
  126. }
  127. if len(errors) == 0 {
  128. return nil
  129. }
  130. err := errFailedToClose
  131. for _, e := range errors {
  132. err = fmt.Errorf("%s; Close error (%v) ", err.Error(), e) //nolint:goerr113
  133. }
  134. return err
  135. }
  136. func (s *Server) readLoop(p net.PacketConn, allocationManager *allocation.Manager) {
  137. buf := make([]byte, s.inboundMTU)
  138. for {
  139. n, addr, err := p.ReadFrom(buf)
  140. switch {
  141. case err != nil:
  142. s.log.Debugf("exit read loop on error: %s", err.Error())
  143. return
  144. case n >= s.inboundMTU:
  145. s.log.Debugf("Read bytes exceeded MTU, packet is possibly truncated")
  146. }
  147. if err := server.HandleRequest(server.Request{
  148. Conn: p,
  149. SrcAddr: addr,
  150. Buff: buf[:n],
  151. Log: s.log,
  152. AuthHandler: s.authHandler,
  153. Realm: s.realm,
  154. AllocationManager: allocationManager,
  155. ChannelBindTimeout: s.channelBindTimeout,
  156. Nonces: s.nonces,
  157. }); err != nil {
  158. s.log.Errorf("error when handling datagram: %v", err)
  159. }
  160. }
  161. }