handshaker.go 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. package dtls
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "crypto/x509"
  6. "fmt"
  7. "io"
  8. "sync"
  9. "time"
  10. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  11. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  12. "github.com/pion/dtls/v2/pkg/protocol/alert"
  13. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  14. "github.com/pion/logging"
  15. )
// [RFC6347 Section-4.2.4]
//                      +-----------+
//                +---> | PREPARING | <--------------------+
//                |     +-----------+                      |
//                |           |                            |
//                |           | Buffer next flight         |
//                |           |                            |
//                |          \|/                           |
//                |     +-----------+                      |
//                |     |  SENDING  |<------------------+  | Send
//                |     +-----------+                   |  | HelloRequest
//        Receive |           |                         |  |
//           next |           | Send flight             |  | or
//         flight |  +--------+                         |  |
//                |  |        | Set retransmit timer    |  | Receive
//                |  |       \|/                        |  | HelloRequest
//                |  |  +-----------+                   |  | Send
//                +--)--|  WAITING  |-------------------+  | ClientHello
//                |  |  +-----------+   Timer expires   |  |
//                |  |        |                         |  |
//                |  |        +-------------------------+  |
//        Receive |  | Send           Read retransmit      |
//           last |  | last                                |
//         flight |  | flight                              |
//                |  |                                     |
//               \|/\|/                                    |
//            +-----------+                                |
//            | FINISHED  | -------------------------------+
//            +-----------+
//                 |  /|\
//                 |   |
//                 +---+
//           Read retransmit
//        Retransmit last flight
  // handshakeState is a state of the handshake FSM. The four states
// Preparing/Sending/Waiting/Finished follow the RFC 6347 Section 4.2.4 diagram;
// handshakeErrored is the zero value and is what the state handlers return
// together with a non-nil error.
type handshakeState uint8

const (
	handshakeErrored handshakeState = iota
	handshakePreparing
	handshakeSending
	handshakeWaiting
	handshakeFinished
)

// String returns the state's display name, or "Unknown" for any value outside
// the declared constants.
func (s handshakeState) String() string {
	names := [...]string{
		handshakeErrored:   "Errored",
		handshakePreparing: "Preparing",
		handshakeSending:   "Sending",
		handshakeWaiting:   "Waiting",
		handshakeFinished:  "Finished",
	}
	if int(s) >= len(names) {
		return "Unknown"
	}
	return names[s]
}
  74. type handshakeFSM struct {
  75. currentFlight flightVal
  76. flights []*packet
  77. retransmit bool
  78. state *State
  79. cache *handshakeCache
  80. cfg *handshakeConfig
  81. closed chan struct{}
  82. }
  83. type handshakeConfig struct {
  84. localPSKCallback PSKCallback
  85. localPSKIdentityHint []byte
  86. localCipherSuites []CipherSuite // Available CipherSuites
  87. localSignatureSchemes []signaturehash.Algorithm // Available signature schemes
  88. extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension
  89. localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
  90. serverName string
  91. supportedProtocols []string
  92. clientAuth ClientAuthType // If we are a client should we request a client certificate
  93. localCertificates []tls.Certificate
  94. nameToCertificate map[string]*tls.Certificate
  95. insecureSkipVerify bool
  96. verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
  97. verifyConnection func(*State) error
  98. sessionStore SessionStore
  99. rootCAs *x509.CertPool
  100. clientCAs *x509.CertPool
  101. retransmitInterval time.Duration
  102. customCipherSuites func() []CipherSuite
  103. ellipticCurves []elliptic.Curve
  104. insecureSkipHelloVerify bool
  105. onFlightState func(flightVal, handshakeState)
  106. log logging.LeveledLogger
  107. keyLogWriter io.Writer
  108. localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error)
  109. localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error)
  110. initialEpoch uint16
  111. mu sync.Mutex
  112. }
  113. type flightConn interface {
  114. notify(ctx context.Context, level alert.Level, desc alert.Description) error
  115. writePackets(context.Context, []*packet) error
  116. recvHandshake() <-chan chan struct{}
  117. setLocalEpoch(epoch uint16)
  118. handleQueuedPackets(context.Context) error
  119. sessionKey() []byte
  120. }
  121. func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
  122. if c.keyLogWriter == nil {
  123. return
  124. }
  125. c.mu.Lock()
  126. defer c.mu.Unlock()
  127. _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret)))
  128. if err != nil {
  129. c.log.Debugf("failed to write key log file: %s", err)
  130. }
  131. }
  // srvCliStr names the local role for log prefixes: "client" when isClient is
// true, "server" otherwise.
func srvCliStr(isClient bool) string {
	role := "server"
	if isClient {
		role = "client"
	}
	return role
}
  138. func newHandshakeFSM(
  139. s *State, cache *handshakeCache, cfg *handshakeConfig,
  140. initialFlight flightVal,
  141. ) *handshakeFSM {
  142. return &handshakeFSM{
  143. currentFlight: initialFlight,
  144. state: s,
  145. cache: cache,
  146. cfg: cfg,
  147. closed: make(chan struct{}),
  148. }
  149. }
  150. func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error {
  151. state := initialState
  152. defer func() {
  153. close(s.closed)
  154. }()
  155. for {
  156. s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String())
  157. if s.cfg.onFlightState != nil {
  158. s.cfg.onFlightState(s.currentFlight, state)
  159. }
  160. var err error
  161. switch state {
  162. case handshakePreparing:
  163. state, err = s.prepare(ctx, c)
  164. case handshakeSending:
  165. state, err = s.send(ctx, c)
  166. case handshakeWaiting:
  167. state, err = s.wait(ctx, c)
  168. case handshakeFinished:
  169. state, err = s.finish(ctx, c)
  170. default:
  171. return errInvalidFSMTransition
  172. }
  173. if err != nil {
  174. return err
  175. }
  176. }
  177. }
  178. func (s *handshakeFSM) Done() <-chan struct{} {
  179. return s.closed
  180. }
  181. func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) {
  182. s.flights = nil
  183. // Prepare flights
  184. var (
  185. a *alert.Alert
  186. err error
  187. pkts []*packet
  188. )
  189. gen, retransmit, errFlight := s.currentFlight.getFlightGenerator()
  190. if errFlight != nil {
  191. err = errFlight
  192. a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}
  193. } else {
  194. pkts, a, err = gen(c, s.state, s.cache, s.cfg)
  195. s.retransmit = retransmit
  196. }
  197. if a != nil {
  198. if alertErr := c.notify(ctx, a.Level, a.Description); alertErr != nil {
  199. if err != nil {
  200. err = alertErr
  201. }
  202. }
  203. }
  204. if err != nil {
  205. return handshakeErrored, err
  206. }
  207. s.flights = pkts
  208. epoch := s.cfg.initialEpoch
  209. nextEpoch := epoch
  210. for _, p := range s.flights {
  211. p.record.Header.Epoch += epoch
  212. if p.record.Header.Epoch > nextEpoch {
  213. nextEpoch = p.record.Header.Epoch
  214. }
  215. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  216. h.Header.MessageSequence = uint16(s.state.handshakeSendSequence)
  217. s.state.handshakeSendSequence++
  218. }
  219. }
  220. if epoch != nextEpoch {
  221. s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch)
  222. c.setLocalEpoch(nextEpoch)
  223. }
  224. return handshakeSending, nil
  225. }
  226. func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) {
  227. // Send flights
  228. if err := c.writePackets(ctx, s.flights); err != nil {
  229. return handshakeErrored, err
  230. }
  231. if s.currentFlight.isLastSendFlight() {
  232. return handshakeFinished, nil
  233. }
  234. return handshakeWaiting, nil
  235. }
  236. func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit
  237. parse, errFlight := s.currentFlight.getFlightParser()
  238. if errFlight != nil {
  239. if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
  240. if errFlight != nil {
  241. return handshakeErrored, alertErr
  242. }
  243. }
  244. return handshakeErrored, errFlight
  245. }
  246. retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
  247. for {
  248. select {
  249. case done := <-c.recvHandshake():
  250. nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
  251. close(done)
  252. if alert != nil {
  253. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  254. if err != nil {
  255. err = alertErr
  256. }
  257. }
  258. }
  259. if err != nil {
  260. return handshakeErrored, err
  261. }
  262. if nextFlight == 0 {
  263. break
  264. }
  265. s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String())
  266. if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
  267. return handshakeFinished, nil
  268. }
  269. s.currentFlight = nextFlight
  270. return handshakePreparing, nil
  271. case <-retransmitTimer.C:
  272. if !s.retransmit {
  273. return handshakeWaiting, nil
  274. }
  275. return handshakeSending, nil
  276. case <-ctx.Done():
  277. return handshakeErrored, ctx.Err()
  278. }
  279. }
  280. }
  281. func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) {
  282. parse, errFlight := s.currentFlight.getFlightParser()
  283. if errFlight != nil {
  284. if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil {
  285. if errFlight != nil {
  286. return handshakeErrored, alertErr
  287. }
  288. }
  289. return handshakeErrored, errFlight
  290. }
  291. retransmitTimer := time.NewTimer(s.cfg.retransmitInterval)
  292. select {
  293. case done := <-c.recvHandshake():
  294. nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg)
  295. close(done)
  296. if alert != nil {
  297. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  298. if err != nil {
  299. err = alertErr
  300. }
  301. }
  302. }
  303. if err != nil {
  304. return handshakeErrored, err
  305. }
  306. if nextFlight == 0 {
  307. break
  308. }
  309. if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight {
  310. return handshakeFinished, nil
  311. }
  312. <-retransmitTimer.C
  313. // Retransmit last flight
  314. return handshakeSending, nil
  315. case <-ctx.Done():
  316. return handshakeErrored, ctx.Err()
  317. }
  318. return handshakeFinished, nil
  319. }