stream.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. package sctp
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "math"
  7. "sync"
  8. "sync/atomic"
  9. "github.com/pion/logging"
  10. )
  11. const (
  12. // ReliabilityTypeReliable is used for reliable transmission
  13. ReliabilityTypeReliable byte = 0
  14. // ReliabilityTypeRexmit is used for partial reliability by retransmission count
  15. ReliabilityTypeRexmit byte = 1
  16. // ReliabilityTypeTimed is used for partial reliability by retransmission duration
  17. ReliabilityTypeTimed byte = 2
  18. )
  19. var (
  20. errOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size")
  21. errStreamClosed = errors.New("Stream closed")
  22. )
  23. // Stream represents an SCTP stream
  24. type Stream struct {
  25. association *Association
  26. lock sync.RWMutex
  27. streamIdentifier uint16
  28. defaultPayloadType PayloadProtocolIdentifier
  29. reassemblyQueue *reassemblyQueue
  30. sequenceNumber uint16
  31. readNotifier *sync.Cond
  32. readErr error
  33. writeErr error
  34. unordered bool
  35. reliabilityType byte
  36. reliabilityValue uint32
  37. bufferedAmount uint64
  38. bufferedAmountLow uint64
  39. onBufferedAmountLow func()
  40. log logging.LeveledLogger
  41. name string
  42. }
  43. // StreamIdentifier returns the Stream identifier associated to the stream.
  44. func (s *Stream) StreamIdentifier() uint16 {
  45. s.lock.RLock()
  46. defer s.lock.RUnlock()
  47. return s.streamIdentifier
  48. }
  49. // SetDefaultPayloadType sets the default payload type used by Write.
  50. func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
  51. atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType))
  52. }
  53. // SetReliabilityParams sets reliability parameters for this stream.
  54. func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
  55. s.lock.Lock()
  56. defer s.lock.Unlock()
  57. s.setReliabilityParams(unordered, relType, relVal)
  58. }
  59. // setReliabilityParams sets reliability parameters for this stream.
  60. // The caller should hold the lock.
  61. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
  62. s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
  63. s.name, !unordered, relType, relVal)
  64. s.unordered = unordered
  65. s.reliabilityType = relType
  66. s.reliabilityValue = relVal
  67. }
  68. // Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
  69. // Returns EOF when the stream is reset or an error if the stream is closed
  70. // otherwise.
  71. func (s *Stream) Read(p []byte) (int, error) {
  72. n, _, err := s.ReadSCTP(p)
  73. return n, err
  74. }
  75. // ReadSCTP reads a packet of len(p) bytes and returns the associated Payload
  76. // Protocol Identifier.
  77. // Returns EOF when the stream is reset or an error if the stream is closed
  78. // otherwise.
  79. func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) {
  80. s.lock.Lock()
  81. defer s.lock.Unlock()
  82. for {
  83. n, ppi, err := s.reassemblyQueue.read(p)
  84. if err == nil {
  85. return n, ppi, nil
  86. } else if errors.Is(err, io.ErrShortBuffer) {
  87. return 0, PayloadProtocolIdentifier(0), err
  88. }
  89. err = s.readErr
  90. if err != nil {
  91. return 0, PayloadProtocolIdentifier(0), err
  92. }
  93. s.readNotifier.Wait()
  94. }
  95. }
  96. func (s *Stream) handleData(pd *chunkPayloadData) {
  97. s.lock.Lock()
  98. defer s.lock.Unlock()
  99. var readable bool
  100. if s.reassemblyQueue.push(pd) {
  101. readable = s.reassemblyQueue.isReadable()
  102. s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable)
  103. if readable {
  104. s.log.Debugf("[%s] readNotifier.signal()", s.name)
  105. s.readNotifier.Signal()
  106. s.log.Debugf("[%s] readNotifier.signal() done", s.name)
  107. }
  108. }
  109. }
  110. func (s *Stream) handleForwardTSNForOrdered(ssn uint16) {
  111. var readable bool
  112. func() {
  113. s.lock.Lock()
  114. defer s.lock.Unlock()
  115. if s.unordered {
  116. return // unordered chunks are handled by handleForwardUnordered method
  117. }
  118. // Remove all chunks older than or equal to the new TSN from
  119. // the reassemblyQueue.
  120. s.reassemblyQueue.forwardTSNForOrdered(ssn)
  121. readable = s.reassemblyQueue.isReadable()
  122. }()
  123. // Notify the reader asynchronously if there's a data chunk to read.
  124. if readable {
  125. s.readNotifier.Signal()
  126. }
  127. }
  128. func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) {
  129. var readable bool
  130. func() {
  131. s.lock.Lock()
  132. defer s.lock.Unlock()
  133. if !s.unordered {
  134. return // ordered chunks are handled by handleForwardTSNOrdered method
  135. }
  136. // Remove all chunks older than or equal to the new TSN from
  137. // the reassemblyQueue.
  138. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN)
  139. readable = s.reassemblyQueue.isReadable()
  140. }()
  141. // Notify the reader asynchronously if there's a data chunk to read.
  142. if readable {
  143. s.readNotifier.Signal()
  144. }
  145. }
  146. // Write writes len(p) bytes from p with the default Payload Protocol Identifier
  147. func (s *Stream) Write(p []byte) (n int, err error) {
  148. ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType)))
  149. return s.WriteSCTP(p, ppi)
  150. }
  151. // WriteSCTP writes len(p) bytes from p to the DTLS connection
  152. func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (n int, err error) {
  153. maxMessageSize := s.association.MaxMessageSize()
  154. if len(p) > int(maxMessageSize) {
  155. return 0, fmt.Errorf("%w: %v", errOutboundPacketTooLarge, math.MaxUint16)
  156. }
  157. switch s.association.getState() {
  158. case shutdownSent, shutdownAckSent, shutdownPending, shutdownReceived:
  159. s.lock.Lock()
  160. if s.writeErr == nil {
  161. s.writeErr = errStreamClosed
  162. }
  163. s.lock.Unlock()
  164. default:
  165. }
  166. s.lock.RLock()
  167. err = s.writeErr
  168. s.lock.RUnlock()
  169. if err != nil {
  170. return 0, err
  171. }
  172. chunks := s.packetize(p, ppi)
  173. return len(p), s.association.sendPayloadData(chunks)
  174. }
  175. func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData {
  176. s.lock.Lock()
  177. defer s.lock.Unlock()
  178. i := uint32(0)
  179. remaining := uint32(len(raw))
  180. // From draft-ietf-rtcweb-data-protocol-09, section 6:
  181. // All Data Channel Establishment Protocol messages MUST be sent using
  182. // ordered delivery and reliable transmission.
  183. unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered
  184. var chunks []*chunkPayloadData
  185. var head *chunkPayloadData
  186. for remaining != 0 {
  187. fragmentSize := min32(s.association.maxPayloadSize, remaining)
  188. // Copy the userdata since we'll have to store it until acked
  189. // and the caller may re-use the buffer in the mean time
  190. userData := make([]byte, fragmentSize)
  191. copy(userData, raw[i:i+fragmentSize])
  192. chunk := &chunkPayloadData{
  193. streamIdentifier: s.streamIdentifier,
  194. userData: userData,
  195. unordered: unordered,
  196. beginningFragment: i == 0,
  197. endingFragment: remaining-fragmentSize == 0,
  198. immediateSack: false,
  199. payloadType: ppi,
  200. streamSequenceNumber: s.sequenceNumber,
  201. head: head,
  202. }
  203. if head == nil {
  204. head = chunk
  205. }
  206. chunks = append(chunks, chunk)
  207. remaining -= fragmentSize
  208. i += fragmentSize
  209. }
  210. // RFC 4960 Sec 6.6
  211. // Note: When transmitting ordered and unordered data, an endpoint does
  212. // not increment its Stream Sequence Number when transmitting a DATA
  213. // chunk with U flag set to 1.
  214. if !unordered {
  215. s.sequenceNumber++
  216. }
  217. s.bufferedAmount += uint64(len(raw))
  218. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  219. return chunks
  220. }
  221. // Close closes the write-direction of the stream.
  222. // Future calls to Write are not permitted after calling Close.
  223. func (s *Stream) Close() error {
  224. if sid, isOpen := func() (uint16, bool) {
  225. s.lock.Lock()
  226. defer s.lock.Unlock()
  227. isOpen := true
  228. if s.writeErr == nil {
  229. s.writeErr = errStreamClosed
  230. } else {
  231. isOpen = false
  232. }
  233. if s.readErr == nil {
  234. s.readErr = io.EOF
  235. } else {
  236. isOpen = false
  237. }
  238. s.readNotifier.Broadcast() // broadcast regardless
  239. return s.streamIdentifier, isOpen
  240. }(); isOpen {
  241. // Reset the outgoing stream
  242. // https://tools.ietf.org/html/rfc6525
  243. return s.association.sendResetRequest(sid)
  244. }
  245. return nil
  246. }
  247. // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream.
  248. func (s *Stream) BufferedAmount() uint64 {
  249. s.lock.RLock()
  250. defer s.lock.RUnlock()
  251. return s.bufferedAmount
  252. }
  253. // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is
  254. // considered "low." Defaults to 0.
  255. func (s *Stream) BufferedAmountLowThreshold() uint64 {
  256. s.lock.RLock()
  257. defer s.lock.RUnlock()
  258. return s.bufferedAmountLow
  259. }
  260. // SetBufferedAmountLowThreshold is used to update the threshold.
  261. // See BufferedAmountLowThreshold().
  262. func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
  263. s.lock.Lock()
  264. defer s.lock.Unlock()
  265. s.bufferedAmountLow = th
  266. }
  267. // OnBufferedAmountLow sets the callback handler which would be called when the number of
  268. // bytes of outgoing data buffered is lower than the threshold.
  269. func (s *Stream) OnBufferedAmountLow(f func()) {
  270. s.lock.Lock()
  271. defer s.lock.Unlock()
  272. s.onBufferedAmountLow = f
  273. }
  274. // This method is called by association's readLoop (go-)routine to notify this stream
  275. // of the specified amount of outgoing data has been delivered to the peer.
  276. func (s *Stream) onBufferReleased(nBytesReleased int) {
  277. if nBytesReleased <= 0 {
  278. return
  279. }
  280. s.lock.Lock()
  281. fromAmount := s.bufferedAmount
  282. if s.bufferedAmount < uint64(nBytesReleased) {
  283. s.bufferedAmount = 0
  284. s.log.Errorf("[%s] released buffer size %d should be <= %d",
  285. s.name, nBytesReleased, s.bufferedAmount)
  286. } else {
  287. s.bufferedAmount -= uint64(nBytesReleased)
  288. }
  289. s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)
  290. if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow {
  291. f := s.onBufferedAmountLow
  292. s.lock.Unlock()
  293. f()
  294. return
  295. }
  296. s.lock.Unlock()
  297. }
  298. func (s *Stream) getNumBytesInReassemblyQueue() int {
  299. // No lock is required as it reads the size with atomic load function.
  300. return s.reassemblyQueue.getNumBytes()
  301. }