srtcp.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package srtp
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "github.com/pion/rtcp"
  6. )
  7. const maxSRTCPIndex = 0x7FFFFFFF
  8. func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
  9. out := allocateIfMismatch(dst, encrypted)
  10. authTagLen, err := c.cipher.rtcpAuthTagLen()
  11. if err != nil {
  12. return nil, err
  13. }
  14. aeadAuthTagLen, err := c.cipher.aeadAuthTagLen()
  15. if err != nil {
  16. return nil, err
  17. }
  18. tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
  19. if tailOffset < aeadAuthTagLen {
  20. return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted))
  21. } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 {
  22. return out, nil
  23. }
  24. index := c.cipher.getRTCPIndex(encrypted)
  25. ssrc := binary.BigEndian.Uint32(encrypted[4:])
  26. s := c.getSRTCPSSRCState(ssrc)
  27. markAsValid, ok := s.replayDetector.Check(uint64(index))
  28. if !ok {
  29. return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index}
  30. }
  31. out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc)
  32. if err != nil {
  33. return nil, err
  34. }
  35. markAsValid()
  36. return out, nil
  37. }
  38. // DecryptRTCP decrypts a buffer that contains a RTCP packet
  39. func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byte, error) {
  40. if header == nil {
  41. header = &rtcp.Header{}
  42. }
  43. if err := header.Unmarshal(encrypted); err != nil {
  44. return nil, err
  45. }
  46. return c.decryptRTCP(dst, encrypted)
  47. }
  48. func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
  49. ssrc := binary.BigEndian.Uint32(decrypted[4:])
  50. s := c.getSRTCPSSRCState(ssrc)
  51. // We roll over early because MSB is used for marking as encrypted
  52. s.srtcpIndex++
  53. if s.srtcpIndex > maxSRTCPIndex {
  54. s.srtcpIndex = 0
  55. }
  56. return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc)
  57. }
  58. // EncryptRTCP Encrypts a RTCP packet
  59. func (c *Context) EncryptRTCP(dst, decrypted []byte, header *rtcp.Header) ([]byte, error) {
  60. if header == nil {
  61. header = &rtcp.Header{}
  62. }
  63. if err := header.Unmarshal(decrypted); err != nil {
  64. return nil, err
  65. }
  66. return c.encryptRTCP(dst, decrypted)
  67. }