session_srtcp.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package srtp
  2. import (
  3. "net"
  4. "time"
  5. "github.com/pion/logging"
  6. "github.com/pion/rtcp"
  7. )
  // defaultSessionSRTCPReplayProtectionWindow is the window passed to
  // SRTCPReplayProtection as the first (default) remote context option of a
  // new SRTCP session, ahead of any caller-supplied RemoteOptions.
  const defaultSessionSRTCPReplayProtectionWindow = 64
  9. // SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session
  10. // SRTCP itself does not have a design like this, but it is common in most applications
  11. // for local/remote to each have their own keying material. This provides those patterns
  12. // instead of making everyone re-implement
  13. type SessionSRTCP struct {
  14. session
  15. writeStream *WriteStreamSRTCP
  16. }
  17. // NewSessionSRTCP creates a SRTCP session using conn as the underlying transport.
  18. func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl
  19. if config == nil {
  20. return nil, errNoConfig
  21. } else if conn == nil {
  22. return nil, errNoConn
  23. }
  24. loggerFactory := config.LoggerFactory
  25. if loggerFactory == nil {
  26. loggerFactory = logging.NewDefaultLoggerFactory()
  27. }
  28. localOpts := append(
  29. []ContextOption{},
  30. config.LocalOptions...,
  31. )
  32. remoteOpts := append(
  33. []ContextOption{
  34. // Default options
  35. SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow),
  36. },
  37. config.RemoteOptions...,
  38. )
  39. s := &SessionSRTCP{
  40. session: session{
  41. nextConn: conn,
  42. localOptions: localOpts,
  43. remoteOptions: remoteOpts,
  44. readStreams: map[uint32]readStream{},
  45. newStream: make(chan readStream),
  46. started: make(chan interface{}),
  47. closed: make(chan interface{}),
  48. bufferFactory: config.BufferFactory,
  49. log: loggerFactory.NewLogger("srtp"),
  50. },
  51. }
  52. s.writeStream = &WriteStreamSRTCP{s}
  53. err := s.session.start(
  54. config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
  55. config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
  56. config.Profile,
  57. s,
  58. )
  59. if err != nil {
  60. return nil, err
  61. }
  62. return s, nil
  63. }
  64. // OpenWriteStream returns the global write stream for the Session
  65. func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) {
  66. return s.writeStream, nil
  67. }
  68. // OpenReadStream opens a read stream for the given SSRC, it can be used
  69. // if you want a certain SSRC, but don't want to wait for AcceptStream
  70. func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {
  71. r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
  72. if readStream, ok := r.(*ReadStreamSRTCP); ok {
  73. return readStream, nil
  74. }
  75. return nil, errFailedTypeAssertion
  76. }
  77. // AcceptStream returns a stream to handle RTCP for a single SSRC
  78. func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
  79. stream, ok := <-s.newStream
  80. if !ok {
  81. return nil, 0, errStreamAlreadyClosed
  82. }
  83. readStream, ok := stream.(*ReadStreamSRTCP)
  84. if !ok {
  85. return nil, 0, errFailedTypeAssertion
  86. }
  87. return readStream, stream.GetSSRC(), nil
  88. }
  89. // Close ends the session
  90. func (s *SessionSRTCP) Close() error {
  91. return s.session.close()
  92. }
  93. // Private
  94. func (s *SessionSRTCP) write(buf []byte) (int, error) {
  95. if _, ok := <-s.session.started; ok {
  96. return 0, errStartedChannelUsedIncorrectly
  97. }
  98. ibuf := bufferpool.Get()
  99. defer bufferpool.Put(ibuf)
  100. s.session.localContextMutex.Lock()
  101. encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil)
  102. s.session.localContextMutex.Unlock()
  103. if err != nil {
  104. return 0, err
  105. }
  106. return s.session.nextConn.Write(encrypted)
  107. }
  108. func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
  109. return s.session.nextConn.SetWriteDeadline(t)
  110. }
  111. // create a list of Destination SSRCs
  112. // that's a superset of all Destinations in the slice.
  113. func destinationSSRC(pkts []rtcp.Packet) []uint32 {
  114. ssrcSet := make(map[uint32]struct{})
  115. for _, p := range pkts {
  116. for _, ssrc := range p.DestinationSSRC() {
  117. ssrcSet[ssrc] = struct{}{}
  118. }
  119. }
  120. out := make([]uint32, 0, len(ssrcSet))
  121. for ssrc := range ssrcSet {
  122. out = append(out, ssrc)
  123. }
  124. return out
  125. }
  126. func (s *SessionSRTCP) decrypt(buf []byte) error {
  127. decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
  128. if err != nil {
  129. return err
  130. }
  131. pkt, err := rtcp.Unmarshal(decrypted)
  132. if err != nil {
  133. return err
  134. }
  135. for _, ssrc := range destinationSSRC(pkt) {
  136. r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
  137. if r == nil {
  138. return nil // Session has been closed
  139. } else if isNew {
  140. s.session.newStream <- r // Notify AcceptStream
  141. }
  142. readStream, ok := r.(*ReadStreamSRTCP)
  143. if !ok {
  144. return errFailedTypeAssertion
  145. }
  146. _, err = readStream.write(decrypted)
  147. if err != nil {
  148. return err
  149. }
  150. }
  151. return nil
  152. }