| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- package srtp
- import (
- "net"
- "time"
- "github.com/pion/logging"
- "github.com/pion/rtcp"
- )
const (
	// defaultSessionSRTCPReplayProtectionWindow is the window size that
	// NewSessionSRTCP passes to SRTCPReplayProtection as the first (default)
	// remote context option, ahead of any caller-supplied RemoteOptions.
	defaultSessionSRTCPReplayProtectionWindow = 64
)
- // SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session
- // SRTCP itself does not have a design like this, but it is common in most applications
- // for local/remote to each have their own keying material. This provides those patterns
- // instead of making everyone re-implement
- type SessionSRTCP struct {
- session
- writeStream *WriteStreamSRTCP
- }
- // NewSessionSRTCP creates a SRTCP session using conn as the underlying transport.
- func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl
- if config == nil {
- return nil, errNoConfig
- } else if conn == nil {
- return nil, errNoConn
- }
- loggerFactory := config.LoggerFactory
- if loggerFactory == nil {
- loggerFactory = logging.NewDefaultLoggerFactory()
- }
- localOpts := append(
- []ContextOption{},
- config.LocalOptions...,
- )
- remoteOpts := append(
- []ContextOption{
- // Default options
- SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow),
- },
- config.RemoteOptions...,
- )
- s := &SessionSRTCP{
- session: session{
- nextConn: conn,
- localOptions: localOpts,
- remoteOptions: remoteOpts,
- readStreams: map[uint32]readStream{},
- newStream: make(chan readStream),
- started: make(chan interface{}),
- closed: make(chan interface{}),
- bufferFactory: config.BufferFactory,
- log: loggerFactory.NewLogger("srtp"),
- },
- }
- s.writeStream = &WriteStreamSRTCP{s}
- err := s.session.start(
- config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt,
- config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt,
- config.Profile,
- s,
- )
- if err != nil {
- return nil, err
- }
- return s, nil
- }
- // OpenWriteStream returns the global write stream for the Session
- func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) {
- return s.writeStream, nil
- }
- // OpenReadStream opens a read stream for the given SSRC, it can be used
- // if you want a certain SSRC, but don't want to wait for AcceptStream
- func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) {
- r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
- if readStream, ok := r.(*ReadStreamSRTCP); ok {
- return readStream, nil
- }
- return nil, errFailedTypeAssertion
- }
- // AcceptStream returns a stream to handle RTCP for a single SSRC
- func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) {
- stream, ok := <-s.newStream
- if !ok {
- return nil, 0, errStreamAlreadyClosed
- }
- readStream, ok := stream.(*ReadStreamSRTCP)
- if !ok {
- return nil, 0, errFailedTypeAssertion
- }
- return readStream, stream.GetSSRC(), nil
- }
- // Close ends the session
- func (s *SessionSRTCP) Close() error {
- return s.session.close()
- }
- // Private
- func (s *SessionSRTCP) write(buf []byte) (int, error) {
- if _, ok := <-s.session.started; ok {
- return 0, errStartedChannelUsedIncorrectly
- }
- ibuf := bufferpool.Get()
- defer bufferpool.Put(ibuf)
- s.session.localContextMutex.Lock()
- encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil)
- s.session.localContextMutex.Unlock()
- if err != nil {
- return 0, err
- }
- return s.session.nextConn.Write(encrypted)
- }
- func (s *SessionSRTCP) setWriteDeadline(t time.Time) error {
- return s.session.nextConn.SetWriteDeadline(t)
- }
- // create a list of Destination SSRCs
- // that's a superset of all Destinations in the slice.
- func destinationSSRC(pkts []rtcp.Packet) []uint32 {
- ssrcSet := make(map[uint32]struct{})
- for _, p := range pkts {
- for _, ssrc := range p.DestinationSSRC() {
- ssrcSet[ssrc] = struct{}{}
- }
- }
- out := make([]uint32, 0, len(ssrcSet))
- for ssrc := range ssrcSet {
- out = append(out, ssrc)
- }
- return out
- }
- func (s *SessionSRTCP) decrypt(buf []byte) error {
- decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil)
- if err != nil {
- return err
- }
- pkt, err := rtcp.Unmarshal(decrypted)
- if err != nil {
- return err
- }
- for _, ssrc := range destinationSSRC(pkt) {
- r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP)
- if r == nil {
- return nil // Session has been closed
- } else if isNew {
- s.session.newStream <- r // Notify AcceptStream
- }
- readStream, ok := r.(*ReadStreamSRTCP)
- if !ok {
- return errFailedTypeAssertion
- }
- _, err = readStream.write(decrypted)
- if err != nil {
- return err
- }
- }
- return nil
- }
|