session_srtp.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package srtp
  2. import (
  3. "net"
  4. "sync"
  5. "time"
  6. "github.com/pion/logging"
  7. "github.com/pion/rtp"
  8. )
  9. const defaultSessionSRTPReplayProtectionWindow = 64
  10. // SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session
  11. // SRTP itself does not have a design like this, but it is common in most applications
  12. // for local/remote to each have their own keying material. This provides those patterns
  13. // instead of making everyone re-implement
  14. type SessionSRTP struct {
  15. session
  16. writeStream *WriteStreamSRTP
  17. }
  18. // NewSessionSRTP creates a SRTP session using conn as the underlying transport.
  19. func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nolint:dupl
  20. if config == nil {
  21. return nil, errNoConfig
  22. } else if conn == nil {
  23. return nil, errNoConn
  24. }
  25. loggerFactory := config.LoggerFactory
  26. if loggerFactory == nil {
  27. loggerFactory = logging.NewDefaultLoggerFactory()
  28. }
  29. localOpts := append(
  30. []ContextOption{},
  31. config.LocalOptions...,
  32. )
  33. remoteOpts := append(
  34. []ContextOption{
  35. // Default options
  36. SRTPReplayProtection(defaultSessionSRTPReplayProtectionWindow),
  37. },
  38. config.RemoteOptions...,
  39. )
  40. s := &SessionSRTP{
  41. session: session{
  42. nextConn: conn,
  43. localOptions: localOpts,
  44. remoteOptions: remoteOpts,
  45. readStreams: map[uint32]readStream{},
  46. newStream: make(chan readStream),
  47. started: make(chan interface{}),
  48. closed: make(chan interface{}),
  49. bufferFactory: config.BufferFactory,
  50. log: loggerFactory.NewLogger("srtp"),
  51. },
  52. }
  53. s.writeStream = &WriteStreamSRTP{s}
  54. err := s.session.start(
  55. config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
  56. config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
  57. config.Profile,
  58. s,
  59. )
  60. if err != nil {
  61. return nil, err
  62. }
  63. return s, nil
  64. }
  65. // OpenWriteStream returns the global write stream for the Session
  66. func (s *SessionSRTP) OpenWriteStream() (*WriteStreamSRTP, error) {
  67. return s.writeStream, nil
  68. }
  69. // OpenReadStream opens a read stream for the given SSRC, it can be used
  70. // if you want a certain SSRC, but don't want to wait for AcceptStream
  71. func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) {
  72. r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTP)
  73. if readStream, ok := r.(*ReadStreamSRTP); ok {
  74. return readStream, nil
  75. }
  76. return nil, errFailedTypeAssertion
  77. }
  78. // AcceptStream returns a stream to handle RTCP for a single SSRC
  79. func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) {
  80. stream, ok := <-s.newStream
  81. if !ok {
  82. return nil, 0, errStreamAlreadyClosed
  83. }
  84. readStream, ok := stream.(*ReadStreamSRTP)
  85. if !ok {
  86. return nil, 0, errFailedTypeAssertion
  87. }
  88. return readStream, stream.GetSSRC(), nil
  89. }
  90. // Close ends the session
  91. func (s *SessionSRTP) Close() error {
  92. return s.session.close()
  93. }
  94. func (s *SessionSRTP) write(b []byte) (int, error) {
  95. packet := &rtp.Packet{}
  96. if err := packet.Unmarshal(b); err != nil {
  97. return 0, err
  98. }
  99. return s.writeRTP(&packet.Header, packet.Payload)
  100. }
  101. // bufferpool is a global pool of buffers used for encrypted packets in
  102. // writeRTP below. Since it's global, buffers can be shared between
  103. // different sessions, which amortizes the cost of allocating the pool.
  104. //
  105. // 1472 is the maximum Ethernet UDP payload. We give ourselves 20 bytes
  106. // of slack for any authentication tags, which is more than enough for
  107. // either CTR or GCM. If the buffer is too small, no harm, it will just
  108. // get expanded by growBuffer.
  109. var bufferpool = sync.Pool{ // nolint:gochecknoglobals
  110. New: func() interface{} {
  111. return make([]byte, 1492)
  112. },
  113. }
  114. func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) {
  115. if _, ok := <-s.session.started; ok {
  116. return 0, errStartedChannelUsedIncorrectly
  117. }
  118. // encryptRTP will either return our buffer, or, if it is too
  119. // small, allocate a new buffer itself. In either case, it is
  120. // safe to put the buffer back into the pool, but only after
  121. // nextConn.Write has returned.
  122. ibuf := bufferpool.Get()
  123. defer bufferpool.Put(ibuf)
  124. s.session.localContextMutex.Lock()
  125. encrypted, err := s.localContext.encryptRTP(ibuf.([]byte), header, payload)
  126. s.session.localContextMutex.Unlock()
  127. if err != nil {
  128. return 0, err
  129. }
  130. return s.session.nextConn.Write(encrypted)
  131. }
  132. func (s *SessionSRTP) setWriteDeadline(t time.Time) error {
  133. return s.session.nextConn.SetWriteDeadline(t)
  134. }
  135. func (s *SessionSRTP) decrypt(buf []byte) error {
  136. h := &rtp.Header{}
  137. headerLen, err := h.Unmarshal(buf)
  138. if err != nil {
  139. return err
  140. }
  141. r, isNew := s.session.getOrCreateReadStream(h.SSRC, s, newReadStreamSRTP)
  142. if r == nil {
  143. return nil // Session has been closed
  144. } else if isNew {
  145. s.session.newStream <- r // Notify AcceptStream
  146. }
  147. readStream, ok := r.(*ReadStreamSRTP)
  148. if !ok {
  149. return errFailedTypeAssertion
  150. }
  151. decrypted, err := s.remoteContext.decryptRTP(buf, buf, h, headerLen)
  152. if err != nil {
  153. return err
  154. }
  155. _, err = readStream.write(decrypted)
  156. if err != nil {
  157. return err
  158. }
  159. return nil
  160. }