handshake_cache.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package dtls
  2. import (
  3. "sync"
  4. "github.com/pion/dtls/v2/pkg/crypto/prf"
  5. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  6. )
  7. type handshakeCacheItem struct {
  8. typ handshake.Type
  9. isClient bool
  10. epoch uint16
  11. messageSequence uint16
  12. data []byte
  13. }
  14. type handshakeCachePullRule struct {
  15. typ handshake.Type
  16. epoch uint16
  17. isClient bool
  18. optional bool
  19. }
  20. type handshakeCache struct {
  21. cache []*handshakeCacheItem
  22. mu sync.Mutex
  23. }
  24. func newHandshakeCache() *handshakeCache {
  25. return &handshakeCache{}
  26. }
  27. func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
  28. h.mu.Lock()
  29. defer h.mu.Unlock()
  30. h.cache = append(h.cache, &handshakeCacheItem{
  31. data: append([]byte{}, data...),
  32. epoch: epoch,
  33. messageSequence: messageSequence,
  34. typ: typ,
  35. isClient: isClient,
  36. })
  37. }
  38. // returns a list handshakes that match the requested rules
  39. // the list will contain null entries for rules that can't be satisfied
  40. // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies)
  41. func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem {
  42. h.mu.Lock()
  43. defer h.mu.Unlock()
  44. out := make([]*handshakeCacheItem, len(rules))
  45. for i, r := range rules {
  46. for _, c := range h.cache {
  47. if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
  48. switch {
  49. case out[i] == nil:
  50. out[i] = c
  51. case out[i].messageSequence < c.messageSequence:
  52. out[i] = c
  53. }
  54. }
  55. }
  56. }
  57. return out
  58. }
  59. // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map.
  60. func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) {
  61. h.mu.Lock()
  62. defer h.mu.Unlock()
  63. ci := make(map[handshake.Type]*handshakeCacheItem)
  64. for _, r := range rules {
  65. var item *handshakeCacheItem
  66. for _, c := range h.cache {
  67. if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch {
  68. switch {
  69. case item == nil:
  70. item = c
  71. case item.messageSequence < c.messageSequence:
  72. item = c
  73. }
  74. }
  75. }
  76. if !r.optional && item == nil {
  77. // Missing mandatory message.
  78. return startSeq, nil, false
  79. }
  80. ci[r.typ] = item
  81. }
  82. out := make(map[handshake.Type]handshake.Message)
  83. seq := startSeq
  84. for _, r := range rules {
  85. t := r.typ
  86. i := ci[t]
  87. if i == nil {
  88. continue
  89. }
  90. var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm
  91. if cipherSuite != nil {
  92. keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm()
  93. }
  94. rawHandshake := &handshake.Handshake{
  95. KeyExchangeAlgorithm: keyExchangeAlgorithm,
  96. }
  97. if err := rawHandshake.Unmarshal(i.data); err != nil {
  98. return startSeq, nil, false
  99. }
  100. if uint16(seq) != rawHandshake.Header.MessageSequence {
  101. // There is a gap. Some messages are not arrived.
  102. return startSeq, nil, false
  103. }
  104. seq++
  105. out[t] = rawHandshake.Message
  106. }
  107. return seq, out, true
  108. }
  109. // pullAndMerge calls pull and then merges the results, ignoring any null entries
  110. func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
  111. merged := []byte{}
  112. for _, p := range h.pull(rules...) {
  113. if p != nil {
  114. merged = append(merged, p.data...)
  115. }
  116. }
  117. return merged
  118. }
  119. // sessionHash returns the session hash for Extended Master Secret support
  120. // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
  121. func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
  122. merged := []byte{}
  123. // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
  124. handshakeBuffer := h.pull(
  125. handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false},
  126. handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false},
  127. handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false},
  128. handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false},
  129. handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false},
  130. handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false},
  131. handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false},
  132. handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false},
  133. )
  134. for _, p := range handshakeBuffer {
  135. if p == nil {
  136. continue
  137. }
  138. merged = append(merged, p.data...)
  139. }
  140. for _, a := range additional {
  141. merged = append(merged, a...)
  142. }
  143. hash := hf()
  144. if _, err := hash.Write(merged); err != nil {
  145. return []byte{}, err
  146. }
  147. return hash.Sum(nil), nil
  148. }