state.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package dtls
  2. import (
  3. "bytes"
  4. "encoding/gob"
  5. "sync/atomic"
  6. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  7. "github.com/pion/dtls/v2/pkg/crypto/prf"
  8. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  9. "github.com/pion/transport/v2/replaydetector"
  10. )
  11. // State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler
  12. type State struct {
  13. localEpoch, remoteEpoch atomic.Value
  14. localSequenceNumber []uint64 // uint48
  15. localRandom, remoteRandom handshake.Random
  16. masterSecret []byte
  17. cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
  18. srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile
  19. PeerCertificates [][]byte
  20. IdentityHint []byte
  21. SessionID []byte
  22. isClient bool
  23. preMasterSecret []byte
  24. extendedMasterSecret bool
  25. namedCurve elliptic.Curve
  26. localKeypair *elliptic.Keypair
  27. cookie []byte
  28. handshakeSendSequence int
  29. handshakeRecvSequence int
  30. serverName string
  31. remoteRequestedCertificate bool // Did we get a CertificateRequest
  32. localCertificatesVerify []byte // cache CertificateVerify
  33. localVerifyData []byte // cached VerifyData
  34. localKeySignature []byte // cached keySignature
  35. peerCertificatesVerified bool
  36. replayDetector []replaydetector.ReplayDetector
  37. peerSupportedProtocols []string
  38. NegotiatedProtocol string
  39. }
  // serializedState is the snapshot of State that MarshalBinary gob-encodes,
  // UnmarshalBinary decodes, and clone round-trips through. gob matches
  // struct fields by name, so renaming or retyping a field here breaks
  // compatibility with previously serialized states.
  type serializedState struct {
  	LocalEpoch     uint16
  	RemoteEpoch    uint16
  	LocalRandom    [handshake.RandomLength]byte
  	RemoteRandom   [handshake.RandomLength]byte
  	CipherSuiteID  uint16 // ID of the chosen cipher suite
  	MasterSecret   []byte
  	SequenceNumber uint64 // outgoing sequence number for LocalEpoch only
  	// SRTPProtectionProfile is the negotiated SRTP protection profile.
  	SRTPProtectionProfile uint16
  	PeerCertificates      [][]byte
  	IdentityHint          []byte
  	SessionID             []byte
  	IsClient              bool
  }
  54. func (s *State) clone() *State {
  55. serialized := s.serialize()
  56. state := &State{}
  57. state.deserialize(*serialized)
  58. return state
  59. }
  60. func (s *State) serialize() *serializedState {
  61. // Marshal random values
  62. localRnd := s.localRandom.MarshalFixed()
  63. remoteRnd := s.remoteRandom.MarshalFixed()
  64. epoch := s.getLocalEpoch()
  65. return &serializedState{
  66. LocalEpoch: s.getLocalEpoch(),
  67. RemoteEpoch: s.getRemoteEpoch(),
  68. CipherSuiteID: uint16(s.cipherSuite.ID()),
  69. MasterSecret: s.masterSecret,
  70. SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
  71. LocalRandom: localRnd,
  72. RemoteRandom: remoteRnd,
  73. SRTPProtectionProfile: uint16(s.srtpProtectionProfile),
  74. PeerCertificates: s.PeerCertificates,
  75. IdentityHint: s.IdentityHint,
  76. SessionID: s.SessionID,
  77. IsClient: s.isClient,
  78. }
  79. }
  80. func (s *State) deserialize(serialized serializedState) {
  81. // Set epoch values
  82. epoch := serialized.LocalEpoch
  83. s.localEpoch.Store(serialized.LocalEpoch)
  84. s.remoteEpoch.Store(serialized.RemoteEpoch)
  85. for len(s.localSequenceNumber) <= int(epoch) {
  86. s.localSequenceNumber = append(s.localSequenceNumber, uint64(0))
  87. }
  88. // Set random values
  89. localRandom := &handshake.Random{}
  90. localRandom.UnmarshalFixed(serialized.LocalRandom)
  91. s.localRandom = *localRandom
  92. remoteRandom := &handshake.Random{}
  93. remoteRandom.UnmarshalFixed(serialized.RemoteRandom)
  94. s.remoteRandom = *remoteRandom
  95. s.isClient = serialized.IsClient
  96. // Set master secret
  97. s.masterSecret = serialized.MasterSecret
  98. // Set cipher suite
  99. s.cipherSuite = cipherSuiteForID(CipherSuiteID(serialized.CipherSuiteID), nil)
  100. atomic.StoreUint64(&s.localSequenceNumber[epoch], serialized.SequenceNumber)
  101. s.srtpProtectionProfile = SRTPProtectionProfile(serialized.SRTPProtectionProfile)
  102. // Set remote certificate
  103. s.PeerCertificates = serialized.PeerCertificates
  104. s.IdentityHint = serialized.IdentityHint
  105. s.SessionID = serialized.SessionID
  106. }
  107. func (s *State) initCipherSuite() error {
  108. if s.cipherSuite.IsInitialized() {
  109. return nil
  110. }
  111. localRandom := s.localRandom.MarshalFixed()
  112. remoteRandom := s.remoteRandom.MarshalFixed()
  113. var err error
  114. if s.isClient {
  115. err = s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
  116. } else {
  117. err = s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
  118. }
  119. if err != nil {
  120. return err
  121. }
  122. return nil
  123. }
  124. // MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation
  125. func (s *State) MarshalBinary() ([]byte, error) {
  126. serialized := s.serialize()
  127. var buf bytes.Buffer
  128. enc := gob.NewEncoder(&buf)
  129. if err := enc.Encode(*serialized); err != nil {
  130. return nil, err
  131. }
  132. return buf.Bytes(), nil
  133. }
  134. // UnmarshalBinary is a binary.BinaryUnmarshaler.UnmarshalBinary implementation
  135. func (s *State) UnmarshalBinary(data []byte) error {
  136. enc := gob.NewDecoder(bytes.NewBuffer(data))
  137. var serialized serializedState
  138. if err := enc.Decode(&serialized); err != nil {
  139. return err
  140. }
  141. s.deserialize(serialized)
  142. if err := s.initCipherSuite(); err != nil {
  143. return err
  144. }
  145. return nil
  146. }
  147. // ExportKeyingMaterial returns length bytes of exported key material in a new
  148. // slice as defined in RFC 5705.
  149. // This allows protocols to use DTLS for key establishment, but
  150. // then use some of the keying material for their own purposes
  151. func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
  152. if s.getLocalEpoch() == 0 {
  153. return nil, errHandshakeInProgress
  154. } else if len(context) != 0 {
  155. return nil, errContextUnsupported
  156. } else if _, ok := invalidKeyingLabels()[label]; ok {
  157. return nil, errReservedExportKeyingMaterial
  158. }
  159. localRandom := s.localRandom.MarshalFixed()
  160. remoteRandom := s.remoteRandom.MarshalFixed()
  161. seed := []byte(label)
  162. if s.isClient {
  163. seed = append(append(seed, localRandom[:]...), remoteRandom[:]...)
  164. } else {
  165. seed = append(append(seed, remoteRandom[:]...), localRandom[:]...)
  166. }
  167. return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
  168. }
  169. func (s *State) getRemoteEpoch() uint16 {
  170. if remoteEpoch, ok := s.remoteEpoch.Load().(uint16); ok {
  171. return remoteEpoch
  172. }
  173. return 0
  174. }
  175. func (s *State) getLocalEpoch() uint16 {
  176. if localEpoch, ok := s.localEpoch.Load().(uint16); ok {
  177. return localEpoch
  178. }
  179. return 0
  180. }