sctptransport.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. //go:build !js
  2. // +build !js
  3. package webrtc
  4. import (
  5. "errors"
  6. "io"
  7. "math"
  8. "sync"
  9. "time"
  10. "github.com/pion/datachannel"
  11. "github.com/pion/logging"
  12. "github.com/pion/sctp"
  13. "github.com/pion/webrtc/v3/pkg/rtcerr"
  14. )
  // sctpMaxChannels is the default limit on data channels, returned by
// MaxChannels when no explicit limit has been recorded.
const sctpMaxChannels uint16 = 65535
  16. // SCTPTransport provides details about the SCTP transport.
  17. type SCTPTransport struct {
  18. lock sync.RWMutex
  19. dtlsTransport *DTLSTransport
  20. // State represents the current state of the SCTP transport.
  21. state SCTPTransportState
  22. // SCTPTransportState doesn't have an enum to distinguish between New/Connecting
  23. // so we need a dedicated field
  24. isStarted bool
  25. // MaxMessageSize represents the maximum size of data that can be passed to
  26. // DataChannel's send() method.
  27. maxMessageSize float64
  28. // MaxChannels represents the maximum amount of DataChannel's that can
  29. // be used simultaneously.
  30. maxChannels *uint16
  31. // OnStateChange func()
  32. onErrorHandler func(error)
  33. sctpAssociation *sctp.Association
  34. onDataChannelHandler func(*DataChannel)
  35. onDataChannelOpenedHandler func(*DataChannel)
  36. // DataChannels
  37. dataChannels []*DataChannel
  38. dataChannelsOpened uint32
  39. dataChannelsRequested uint32
  40. dataChannelsAccepted uint32
  41. api *API
  42. log logging.LeveledLogger
  43. }
  44. // NewSCTPTransport creates a new SCTPTransport.
  45. // This constructor is part of the ORTC API. It is not
  46. // meant to be used together with the basic WebRTC API.
  47. func (api *API) NewSCTPTransport(dtls *DTLSTransport) *SCTPTransport {
  48. res := &SCTPTransport{
  49. dtlsTransport: dtls,
  50. state: SCTPTransportStateConnecting,
  51. api: api,
  52. log: api.settingEngine.LoggerFactory.NewLogger("ortc"),
  53. }
  54. res.updateMessageSize()
  55. res.updateMaxChannels()
  56. return res
  57. }
  58. // Transport returns the DTLSTransport instance the SCTPTransport is sending over.
  59. func (r *SCTPTransport) Transport() *DTLSTransport {
  60. r.lock.RLock()
  61. defer r.lock.RUnlock()
  62. return r.dtlsTransport
  63. }
  64. // GetCapabilities returns the SCTPCapabilities of the SCTPTransport.
  65. func (r *SCTPTransport) GetCapabilities() SCTPCapabilities {
  66. return SCTPCapabilities{
  67. MaxMessageSize: 0,
  68. }
  69. }
  70. // Start the SCTPTransport. Since both local and remote parties must mutually
  71. // create an SCTPTransport, SCTP SO (Simultaneous Open) is used to establish
  72. // a connection over SCTP.
  73. func (r *SCTPTransport) Start(remoteCaps SCTPCapabilities) error {
  74. if r.isStarted {
  75. return nil
  76. }
  77. r.isStarted = true
  78. dtlsTransport := r.Transport()
  79. if dtlsTransport == nil || dtlsTransport.conn == nil {
  80. return errSCTPTransportDTLS
  81. }
  82. sctpAssociation, err := sctp.Client(sctp.Config{
  83. NetConn: dtlsTransport.conn,
  84. MaxReceiveBufferSize: r.api.settingEngine.sctp.maxReceiveBufferSize,
  85. LoggerFactory: r.api.settingEngine.LoggerFactory,
  86. })
  87. if err != nil {
  88. return err
  89. }
  90. r.lock.Lock()
  91. r.sctpAssociation = sctpAssociation
  92. r.state = SCTPTransportStateConnected
  93. dataChannels := append([]*DataChannel{}, r.dataChannels...)
  94. r.lock.Unlock()
  95. var openedDCCount uint32
  96. for _, d := range dataChannels {
  97. if d.ReadyState() == DataChannelStateConnecting {
  98. err := d.open(r)
  99. if err != nil {
  100. r.log.Warnf("failed to open data channel: %s", err)
  101. continue
  102. }
  103. openedDCCount++
  104. }
  105. }
  106. r.lock.Lock()
  107. r.dataChannelsOpened += openedDCCount
  108. r.lock.Unlock()
  109. go r.acceptDataChannels(sctpAssociation)
  110. return nil
  111. }
  112. // Stop stops the SCTPTransport
  113. func (r *SCTPTransport) Stop() error {
  114. r.lock.Lock()
  115. defer r.lock.Unlock()
  116. if r.sctpAssociation == nil {
  117. return nil
  118. }
  119. err := r.sctpAssociation.Close()
  120. if err != nil {
  121. return err
  122. }
  123. r.sctpAssociation = nil
  124. r.state = SCTPTransportStateClosed
  125. return nil
  126. }
  // acceptDataChannels runs the loop that accepts data channels opened by the
// remote peer on association a; Start launches it in its own goroutine.
//
// The loop ends when datachannel.Accept fails. io.EOF is treated as a normal
// shutdown. Any other error is logged and forwarded to the OnError handler.
// For each new channel, the OnDataChannel handler is waited on before the
// channel is opened. Afterwards the opened counter is bumped and the
// OnDataChannelOpened handler (if any) is called.
func (r *SCTPTransport) acceptDataChannels(a *sctp.Association) {
	// Snapshot the underlying datachannel.DataChannel of every channel that
	// already has one. They are passed to datachannel.Accept and used below to
	// skip streams belonging to a channel this transport already knows.
	r.lock.RLock()
	dataChannels := make([]*datachannel.DataChannel, 0, len(r.dataChannels))
	for _, dc := range r.dataChannels {
		// Read the field exactly once while holding dc.mu, so the nil check
		// and the value appended are the same snapshot.
		dc.mu.Lock()
		inner := dc.dataChannel
		dc.mu.Unlock()
		if inner == nil {
			continue
		}
		dataChannels = append(dataChannels, inner)
	}
	r.lock.RUnlock()

ACCEPT:
	for {
		dc, err := datachannel.Accept(a, &datachannel.Config{
			LoggerFactory: r.api.settingEngine.LoggerFactory,
		}, dataChannels...)
		if err != nil {
			if !errors.Is(err, io.EOF) {
				r.log.Errorf("Failed to accept data channel: %v", err)
				r.onError(err)
			}
			return
		}

		// A stream that belongs to an already-known channel needs no new
		// DataChannel wrapper.
		for _, ch := range dataChannels {
			if ch.StreamIdentifier() == dc.StreamIdentifier() {
				continue ACCEPT
			}
		}

		// Translate the negotiated channel type into ordering and
		// partial-reliability parameters; ReliabilityParameter is either a
		// retransmit count or a packet lifetime depending on the type.
		var (
			maxRetransmits    *uint16
			maxPacketLifeTime *uint16
		)
		val := uint16(dc.Config.ReliabilityParameter)
		ordered := true

		switch dc.Config.ChannelType {
		case datachannel.ChannelTypeReliable:
			ordered = true
		case datachannel.ChannelTypeReliableUnordered:
			ordered = false
		case datachannel.ChannelTypePartialReliableRexmit:
			ordered = true
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableRexmitUnordered:
			ordered = false
			maxRetransmits = &val
		case datachannel.ChannelTypePartialReliableTimed:
			ordered = true
			maxPacketLifeTime = &val
		case datachannel.ChannelTypePartialReliableTimedUnordered:
			ordered = false
			maxPacketLifeTime = &val
		default:
		}

		sid := dc.StreamIdentifier()
		rtcDC, err := r.api.newDataChannel(&DataChannelParameters{
			ID:                &sid,
			Label:             dc.Config.Label,
			Protocol:          dc.Config.Protocol,
			Negotiated:        dc.Config.Negotiated,
			Ordered:           ordered,
			MaxPacketLifeTime: maxPacketLifeTime,
			MaxRetransmits:    maxRetransmits,
		}, r.api.settingEngine.LoggerFactory.NewLogger("ortc"))
		if err != nil {
			r.log.Errorf("Failed to accept data channel: %v", err)
			r.onError(err)
			return
		}

		// Wait for the OnDataChannel handler before opening, so setup done
		// there is in place before the channel's own events can fire.
		<-r.onDataChannel(rtcDC)
		rtcDC.handleOpen(dc, true, dc.Config.Negotiated)

		r.lock.Lock()
		r.dataChannelsOpened++
		handler := r.onDataChannelOpenedHandler
		r.lock.Unlock()

		if handler != nil {
			handler(rtcDC)
		}
	}
}
  208. // OnError sets an event handler which is invoked when
  209. // the SCTP connection error occurs.
  210. func (r *SCTPTransport) OnError(f func(err error)) {
  211. r.lock.Lock()
  212. defer r.lock.Unlock()
  213. r.onErrorHandler = f
  214. }
  215. func (r *SCTPTransport) onError(err error) {
  216. r.lock.RLock()
  217. handler := r.onErrorHandler
  218. r.lock.RUnlock()
  219. if handler != nil {
  220. go handler(err)
  221. }
  222. }
  223. // OnDataChannel sets an event handler which is invoked when a data
  224. // channel message arrives from a remote peer.
  225. func (r *SCTPTransport) OnDataChannel(f func(*DataChannel)) {
  226. r.lock.Lock()
  227. defer r.lock.Unlock()
  228. r.onDataChannelHandler = f
  229. }
  230. // OnDataChannelOpened sets an event handler which is invoked when a data
  231. // channel is opened
  232. func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
  233. r.lock.Lock()
  234. defer r.lock.Unlock()
  235. r.onDataChannelOpenedHandler = f
  236. }
  237. func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
  238. r.lock.Lock()
  239. r.dataChannels = append(r.dataChannels, dc)
  240. r.dataChannelsAccepted++
  241. handler := r.onDataChannelHandler
  242. r.lock.Unlock()
  243. done = make(chan struct{})
  244. if handler == nil || dc == nil {
  245. close(done)
  246. return
  247. }
  248. // Run this synchronously to allow setup done in onDataChannelFn()
  249. // to complete before datachannel event handlers might be called.
  250. go func() {
  251. handler(dc)
  252. close(done)
  253. }()
  254. return
  255. }
  256. func (r *SCTPTransport) updateMessageSize() {
  257. r.lock.Lock()
  258. defer r.lock.Unlock()
  259. var remoteMaxMessageSize float64 = 65536 // pion/webrtc#758
  260. var canSendSize float64 = 65536 // pion/webrtc#758
  261. r.maxMessageSize = r.calcMessageSize(remoteMaxMessageSize, canSendSize)
  262. }
  // calcMessageSize combines the remote peer's maximum message size with the
// local sendable size, treating 0 on either side as "no limit".
//
// It returns +Inf when both are 0, the other value when exactly one is 0,
// and the smaller of the two otherwise.
func (r *SCTPTransport) calcMessageSize(remoteMaxMessageSize, canSendSize float64) float64 {
	if remoteMaxMessageSize == 0 {
		if canSendSize == 0 {
			return math.Inf(1)
		}
		return canSendSize
	}

	if canSendSize == 0 || canSendSize > remoteMaxMessageSize {
		return remoteMaxMessageSize
	}
	return canSendSize
}
  278. func (r *SCTPTransport) updateMaxChannels() {
  279. val := sctpMaxChannels
  280. r.maxChannels = &val
  281. }
  282. // MaxChannels is the maximum number of RTCDataChannels that can be open simultaneously.
  283. func (r *SCTPTransport) MaxChannels() uint16 {
  284. r.lock.Lock()
  285. defer r.lock.Unlock()
  286. if r.maxChannels == nil {
  287. return sctpMaxChannels
  288. }
  289. return *r.maxChannels
  290. }
  291. // State returns the current state of the SCTPTransport
  292. func (r *SCTPTransport) State() SCTPTransportState {
  293. r.lock.RLock()
  294. defer r.lock.RUnlock()
  295. return r.state
  296. }
  // collectStats reports a TransportStats entry with ID "sctpTransport" to
// collector. Byte counters are filled in only when an association exists;
// otherwise they keep their zero values.
func (r *SCTPTransport) collectStats(collector *statsReportCollector) {
	collector.Collecting()

	stats := TransportStats{
		Timestamp: statsTimestampFrom(time.Now()),
		Type:      StatsTypeTransport,
		ID:        "sctpTransport",
	}

	if assoc := r.association(); assoc != nil {
		stats.BytesSent, stats.BytesReceived = assoc.BytesSent(), assoc.BytesReceived()
	}

	collector.Collect(stats.ID, stats)
}
  311. func (r *SCTPTransport) generateAndSetDataChannelID(dtlsRole DTLSRole, idOut **uint16) error {
  312. var id uint16
  313. if dtlsRole != DTLSRoleClient {
  314. id++
  315. }
  316. max := r.MaxChannels()
  317. r.lock.Lock()
  318. defer r.lock.Unlock()
  319. // Create map of ids so we can compare without double-looping each time.
  320. idsMap := make(map[uint16]struct{}, len(r.dataChannels))
  321. for _, dc := range r.dataChannels {
  322. if dc.ID() == nil {
  323. continue
  324. }
  325. idsMap[*dc.ID()] = struct{}{}
  326. }
  327. for ; id < max-1; id += 2 {
  328. if _, ok := idsMap[id]; ok {
  329. continue
  330. }
  331. *idOut = &id
  332. return nil
  333. }
  334. return &rtcerr.OperationError{Err: ErrMaxDataChannelID}
  335. }
  336. func (r *SCTPTransport) association() *sctp.Association {
  337. if r == nil {
  338. return nil
  339. }
  340. r.lock.RLock()
  341. association := r.sctpAssociation
  342. r.lock.RUnlock()
  343. return association
  344. }