datachannel.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. // Package datachannel implements WebRTC Data Channels
  2. package datachannel
  3. import (
  4. "fmt"
  5. "io"
  6. "sync"
  7. "sync/atomic"
  8. "github.com/pion/logging"
  9. "github.com/pion/sctp"
  10. )
  11. const receiveMTU = 8192
  12. // Reader is an extended io.Reader
  13. // that also returns if the message is text.
  14. type Reader interface {
  15. ReadDataChannel([]byte) (int, bool, error)
  16. }
  17. // Writer is an extended io.Writer
  18. // that also allows indicating if a message is text.
  19. type Writer interface {
  20. WriteDataChannel([]byte, bool) (int, error)
  21. }
  22. // ReadWriteCloser is an extended io.ReadWriteCloser
  23. // that also implements our Reader and Writer.
  24. type ReadWriteCloser interface {
  25. io.Reader
  26. io.Writer
  27. Reader
  28. Writer
  29. io.Closer
  30. }
  31. // DataChannel represents a data channel
  32. type DataChannel struct {
  33. Config
  34. // stats
  35. messagesSent uint32
  36. messagesReceived uint32
  37. bytesSent uint64
  38. bytesReceived uint64
  39. mu sync.Mutex
  40. onOpenCompleteHandler func()
  41. openCompleteHandlerOnce sync.Once
  42. stream *sctp.Stream
  43. log logging.LeveledLogger
  44. }
  45. // Config is used to configure the data channel.
  46. type Config struct {
  47. ChannelType ChannelType
  48. Negotiated bool
  49. Priority uint16
  50. ReliabilityParameter uint32
  51. Label string
  52. Protocol string
  53. LoggerFactory logging.LoggerFactory
  54. }
  55. func newDataChannel(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  56. return &DataChannel{
  57. Config: *config,
  58. stream: stream,
  59. log: config.LoggerFactory.NewLogger("datachannel"),
  60. }, nil
  61. }
  62. // Dial opens a data channels over SCTP
  63. func Dial(a *sctp.Association, id uint16, config *Config) (*DataChannel, error) {
  64. stream, err := a.OpenStream(id, sctp.PayloadTypeWebRTCBinary)
  65. if err != nil {
  66. return nil, err
  67. }
  68. dc, err := Client(stream, config)
  69. if err != nil {
  70. return nil, err
  71. }
  72. return dc, nil
  73. }
  74. // Client opens a data channel over an SCTP stream
  75. func Client(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  76. msg := &channelOpen{
  77. ChannelType: config.ChannelType,
  78. Priority: config.Priority,
  79. ReliabilityParameter: config.ReliabilityParameter,
  80. Label: []byte(config.Label),
  81. Protocol: []byte(config.Protocol),
  82. }
  83. if !config.Negotiated {
  84. rawMsg, err := msg.Marshal()
  85. if err != nil {
  86. return nil, fmt.Errorf("failed to marshal ChannelOpen %w", err)
  87. }
  88. if _, err = stream.WriteSCTP(rawMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
  89. return nil, fmt.Errorf("failed to send ChannelOpen %w", err)
  90. }
  91. }
  92. return newDataChannel(stream, config)
  93. }
  94. // Accept is used to accept incoming data channels over SCTP
  95. func Accept(a *sctp.Association, config *Config, existingChannels ...*DataChannel) (*DataChannel, error) {
  96. stream, err := a.AcceptStream()
  97. if err != nil {
  98. return nil, err
  99. }
  100. for _, ch := range existingChannels {
  101. if ch.StreamIdentifier() == stream.StreamIdentifier() {
  102. ch.stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
  103. return ch, nil
  104. }
  105. }
  106. stream.SetDefaultPayloadType(sctp.PayloadTypeWebRTCBinary)
  107. dc, err := Server(stream, config)
  108. if err != nil {
  109. return nil, err
  110. }
  111. return dc, nil
  112. }
  113. // Server accepts a data channel over an SCTP stream
  114. func Server(stream *sctp.Stream, config *Config) (*DataChannel, error) {
  115. buffer := make([]byte, receiveMTU)
  116. n, ppi, err := stream.ReadSCTP(buffer)
  117. if err != nil {
  118. return nil, err
  119. }
  120. if ppi != sctp.PayloadTypeWebRTCDCEP {
  121. return nil, fmt.Errorf("%w %s", ErrInvalidPayloadProtocolIdentifier, ppi)
  122. }
  123. openMsg, err := parseExpectDataChannelOpen(buffer[:n])
  124. if err != nil {
  125. return nil, fmt.Errorf("failed to parse DataChannelOpen packet %w", err)
  126. }
  127. config.ChannelType = openMsg.ChannelType
  128. config.Priority = openMsg.Priority
  129. config.ReliabilityParameter = openMsg.ReliabilityParameter
  130. config.Label = string(openMsg.Label)
  131. config.Protocol = string(openMsg.Protocol)
  132. dataChannel, err := newDataChannel(stream, config)
  133. if err != nil {
  134. return nil, err
  135. }
  136. err = dataChannel.writeDataChannelAck()
  137. if err != nil {
  138. return nil, err
  139. }
  140. err = dataChannel.commitReliabilityParams()
  141. if err != nil {
  142. return nil, err
  143. }
  144. return dataChannel, nil
  145. }
  146. // Read reads a packet of len(p) bytes as binary data
  147. func (c *DataChannel) Read(p []byte) (int, error) {
  148. n, _, err := c.ReadDataChannel(p)
  149. return n, err
  150. }
  151. // ReadDataChannel reads a packet of len(p) bytes
  152. func (c *DataChannel) ReadDataChannel(p []byte) (int, bool, error) {
  153. for {
  154. n, ppi, err := c.stream.ReadSCTP(p)
  155. if err == io.EOF {
  156. // When the peer sees that an incoming stream was
  157. // reset, it also resets its corresponding outgoing stream.
  158. if closeErr := c.stream.Close(); closeErr != nil {
  159. return 0, false, closeErr
  160. }
  161. }
  162. if err != nil {
  163. return 0, false, err
  164. }
  165. if ppi == sctp.PayloadTypeWebRTCDCEP {
  166. if err = c.handleDCEP(p[:n]); err != nil {
  167. c.log.Errorf("Failed to handle DCEP: %s", err.Error())
  168. }
  169. continue
  170. } else if ppi == sctp.PayloadTypeWebRTCBinaryEmpty || ppi == sctp.PayloadTypeWebRTCStringEmpty {
  171. n = 0
  172. }
  173. atomic.AddUint32(&c.messagesReceived, 1)
  174. atomic.AddUint64(&c.bytesReceived, uint64(n))
  175. isString := ppi == sctp.PayloadTypeWebRTCString || ppi == sctp.PayloadTypeWebRTCStringEmpty
  176. return n, isString, err
  177. }
  178. }
  179. // MessagesSent returns the number of messages sent
  180. func (c *DataChannel) MessagesSent() uint32 {
  181. return atomic.LoadUint32(&c.messagesSent)
  182. }
  183. // MessagesReceived returns the number of messages received
  184. func (c *DataChannel) MessagesReceived() uint32 {
  185. return atomic.LoadUint32(&c.messagesReceived)
  186. }
  187. // OnOpen sets an event handler which is invoked when
  188. // a DATA_CHANNEL_ACK message is received.
  189. // The handler is called only on thefor the channel opened
  190. // https://datatracker.ietf.org/doc/html/draft-ietf-rtcweb-data-protocol-09#section-5.2
  191. func (c *DataChannel) OnOpen(f func()) {
  192. c.mu.Lock()
  193. c.openCompleteHandlerOnce = sync.Once{}
  194. c.onOpenCompleteHandler = f
  195. c.mu.Unlock()
  196. }
  197. func (c *DataChannel) onOpenComplete() {
  198. c.mu.Lock()
  199. hdlr := c.onOpenCompleteHandler
  200. c.mu.Unlock()
  201. if hdlr != nil {
  202. go c.openCompleteHandlerOnce.Do(func() {
  203. hdlr()
  204. })
  205. }
  206. }
  207. // BytesSent returns the number of bytes sent
  208. func (c *DataChannel) BytesSent() uint64 {
  209. return atomic.LoadUint64(&c.bytesSent)
  210. }
  211. // BytesReceived returns the number of bytes received
  212. func (c *DataChannel) BytesReceived() uint64 {
  213. return atomic.LoadUint64(&c.bytesReceived)
  214. }
  215. // StreamIdentifier returns the Stream identifier associated to the stream.
  216. func (c *DataChannel) StreamIdentifier() uint16 {
  217. return c.stream.StreamIdentifier()
  218. }
  219. func (c *DataChannel) handleDCEP(data []byte) error {
  220. msg, err := parse(data)
  221. if err != nil {
  222. return fmt.Errorf("failed to parse DataChannel packet %w", err)
  223. }
  224. switch msg := msg.(type) {
  225. case *channelAck:
  226. c.log.Debug("Received DATA_CHANNEL_ACK")
  227. if err = c.commitReliabilityParams(); err != nil {
  228. return err
  229. }
  230. c.onOpenComplete()
  231. default:
  232. return fmt.Errorf("%w %v", ErrInvalidMessageType, msg)
  233. }
  234. return nil
  235. }
  236. // Write writes len(p) bytes from p as binary data
  237. func (c *DataChannel) Write(p []byte) (n int, err error) {
  238. return c.WriteDataChannel(p, false)
  239. }
  240. // WriteDataChannel writes len(p) bytes from p
  241. func (c *DataChannel) WriteDataChannel(p []byte, isString bool) (n int, err error) {
  242. // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6
  243. // SCTP does not support the sending of empty user messages. Therefore,
  244. // if an empty message has to be sent, the appropriate PPID (WebRTC
  245. // String Empty or WebRTC Binary Empty) is used and the SCTP user
  246. // message of one zero byte is sent. When receiving an SCTP user
  247. // message with one of these PPIDs, the receiver MUST ignore the SCTP
  248. // user message and process it as an empty message.
  249. var ppi sctp.PayloadProtocolIdentifier
  250. switch {
  251. case !isString && len(p) > 0:
  252. ppi = sctp.PayloadTypeWebRTCBinary
  253. case !isString && len(p) == 0:
  254. ppi = sctp.PayloadTypeWebRTCBinaryEmpty
  255. case isString && len(p) > 0:
  256. ppi = sctp.PayloadTypeWebRTCString
  257. case isString && len(p) == 0:
  258. ppi = sctp.PayloadTypeWebRTCStringEmpty
  259. }
  260. atomic.AddUint32(&c.messagesSent, 1)
  261. atomic.AddUint64(&c.bytesSent, uint64(len(p)))
  262. if len(p) == 0 {
  263. _, err := c.stream.WriteSCTP([]byte{0}, ppi)
  264. return 0, err
  265. }
  266. return c.stream.WriteSCTP(p, ppi)
  267. }
  268. func (c *DataChannel) writeDataChannelAck() error {
  269. ack := channelAck{}
  270. ackMsg, err := ack.Marshal()
  271. if err != nil {
  272. return fmt.Errorf("failed to marshal ChannelOpen ACK: %w", err)
  273. }
  274. if _, err = c.stream.WriteSCTP(ackMsg, sctp.PayloadTypeWebRTCDCEP); err != nil {
  275. return fmt.Errorf("failed to send ChannelOpen ACK: %w", err)
  276. }
  277. return err
  278. }
  279. // Close closes the DataChannel and the underlying SCTP stream.
  280. func (c *DataChannel) Close() error {
  281. // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7
  282. // Closing of a data channel MUST be signaled by resetting the
  283. // corresponding outgoing streams [RFC6525]. This means that if one
  284. // side decides to close the data channel, it resets the corresponding
  285. // outgoing stream. When the peer sees that an incoming stream was
  286. // reset, it also resets its corresponding outgoing stream. Once this
  287. // is completed, the data channel is closed. Resetting a stream sets
  288. // the Stream Sequence Numbers (SSNs) of the stream back to 'zero' with
  289. // a corresponding notification to the application layer that the reset
  290. // has been performed. Streams are available for reuse after a reset
  291. // has been performed.
  292. return c.stream.Close()
  293. }
  294. // BufferedAmount returns the number of bytes of data currently queued to be
  295. // sent over this stream.
  296. func (c *DataChannel) BufferedAmount() uint64 {
  297. return c.stream.BufferedAmount()
  298. }
  299. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing
  300. // data that is considered "low." Defaults to 0.
  301. func (c *DataChannel) BufferedAmountLowThreshold() uint64 {
  302. return c.stream.BufferedAmountLowThreshold()
  303. }
  304. // SetBufferedAmountLowThreshold is used to update the threshold.
  305. // See BufferedAmountLowThreshold().
  306. func (c *DataChannel) SetBufferedAmountLowThreshold(th uint64) {
  307. c.stream.SetBufferedAmountLowThreshold(th)
  308. }
  309. // OnBufferedAmountLow sets the callback handler which would be called when the
  310. // number of bytes of outgoing data buffered is lower than the threshold.
  311. func (c *DataChannel) OnBufferedAmountLow(f func()) {
  312. c.stream.OnBufferedAmountLow(f)
  313. }
  314. func (c *DataChannel) commitReliabilityParams() error {
  315. switch c.Config.ChannelType {
  316. case ChannelTypeReliable:
  317. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
  318. case ChannelTypeReliableUnordered:
  319. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeReliable, c.Config.ReliabilityParameter)
  320. case ChannelTypePartialReliableRexmit:
  321. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
  322. case ChannelTypePartialReliableRexmitUnordered:
  323. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeRexmit, c.Config.ReliabilityParameter)
  324. case ChannelTypePartialReliableTimed:
  325. c.stream.SetReliabilityParams(false, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
  326. case ChannelTypePartialReliableTimedUnordered:
  327. c.stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, c.Config.ReliabilityParameter)
  328. default:
  329. return fmt.Errorf("%w %v", ErrInvalidChannelType, c.Config.ChannelType)
  330. }
  331. return nil
  332. }