conn.go 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024
  1. package dtls
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. "github.com/pion/dtls/v2/internal/closer"
  12. "github.com/pion/dtls/v2/pkg/crypto/elliptic"
  13. "github.com/pion/dtls/v2/pkg/crypto/signaturehash"
  14. "github.com/pion/dtls/v2/pkg/protocol"
  15. "github.com/pion/dtls/v2/pkg/protocol/alert"
  16. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  17. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  18. "github.com/pion/logging"
  19. "github.com/pion/transport/v2/connctx"
  20. "github.com/pion/transport/v2/deadline"
  21. "github.com/pion/transport/v2/replaydetector"
  22. )
// Package-level defaults used while setting up and running a connection.
const (
	// initialTickerInterval is the retransmit interval used by createConn
	// when Config.FlightInterval is zero.
	initialTickerInterval = time.Second
	// cookieLength and sessionLength are not referenced in this file;
	// presumably the byte lengths of the hello-verify cookie and the
	// session ID — TODO confirm against their users.
	cookieLength  = 20
	sessionLength = 32
	// defaultNamedCurve is not referenced in this file.
	defaultNamedCurve = elliptic.X25519
	// inboundBufferSize is the length of each buffer handed out by
	// poolReadBuffer for reading from the underlying connection.
	inboundBufferSize = 8192
	// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
	defaultReplayProtectionWindow = 64
)
  32. func invalidKeyingLabels() map[string]bool {
  33. return map[string]bool{
  34. "client finished": true,
  35. "server finished": true,
  36. "master secret": true,
  37. "key expansion": true,
  38. }
  39. }
// Conn represents a DTLS connection
type Conn struct {
	// lock is taken exclusively by writePackets and shared by
	// ConnectionState and SelectedSRTPProtectionProfile.
	lock           sync.RWMutex    // Internal lock (must not be public)
	nextConn       connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from
	fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling
	handshakeCache *handshakeCache // caching of handshake messages for verifyData generation
	// decrypted carries either []byte payloads or error values; it is closed
	// by the read loop started in handshake when that loop exits.
	decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read`

	state State // Internal state

	// maximumTransmissionUnit bounds handshake fragment size and the size of
	// combined outgoing datagrams.
	maximumTransmissionUnit int

	// handshakeCompletedSuccessfully holds a struct{ bool }; see
	// setHandshakeCompletedSuccessfully / isHandshakeCompletedSuccessfully.
	handshakeCompletedSuccessfully atomic.Value

	// encryptedPackets holds records that handleIncomingPacket could not
	// process yet; handleQueuedPackets replays them.
	encryptedPackets [][]byte

	// connectionClosedByUser is set (under closeLock) once close(true) ran.
	connectionClosedByUser bool
	closeLock              sync.Mutex
	// closed signals that the connection has been closed.
	closed *closer.Closer
	// handshakeLoopsFinished tracks the two goroutines started by handshake.
	handshakeLoopsFinished sync.WaitGroup

	readDeadline  *deadline.Deadline
	writeDeadline *deadline.Deadline

	log logging.LeveledLogger

	// reading is created with capacity 1 in createConn; it is not otherwise
	// used in this file.
	reading chan struct{}
	// handshakeRecv receives a "done" channel from readAndBuffer whenever a
	// datagram contained handshake data.
	handshakeRecv chan chan struct{}
	// cancelHandshaker and cancelHandshakeReader cancel the contexts of the
	// handshake FSM goroutine and the read loop goroutine respectively.
	cancelHandshaker      func()
	cancelHandshakeReader func()

	fsm *handshakeFSM

	// replayProtectionWindow is the window size passed to each per-epoch
	// replay detector.
	replayProtectionWindow uint
}
  65. func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
  66. err := validateConfig(config)
  67. if err != nil {
  68. return nil, err
  69. }
  70. if nextConn == nil {
  71. return nil, errNilNextConn
  72. }
  73. cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
  74. if err != nil {
  75. return nil, err
  76. }
  77. signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
  78. if err != nil {
  79. return nil, err
  80. }
  81. workerInterval := initialTickerInterval
  82. if config.FlightInterval != 0 {
  83. workerInterval = config.FlightInterval
  84. }
  85. loggerFactory := config.LoggerFactory
  86. if loggerFactory == nil {
  87. loggerFactory = logging.NewDefaultLoggerFactory()
  88. }
  89. logger := loggerFactory.NewLogger("dtls")
  90. mtu := config.MTU
  91. if mtu <= 0 {
  92. mtu = defaultMTU
  93. }
  94. replayProtectionWindow := config.ReplayProtectionWindow
  95. if replayProtectionWindow <= 0 {
  96. replayProtectionWindow = defaultReplayProtectionWindow
  97. }
  98. c := &Conn{
  99. nextConn: connctx.New(nextConn),
  100. fragmentBuffer: newFragmentBuffer(),
  101. handshakeCache: newHandshakeCache(),
  102. maximumTransmissionUnit: mtu,
  103. decrypted: make(chan interface{}, 1),
  104. log: logger,
  105. readDeadline: deadline.New(),
  106. writeDeadline: deadline.New(),
  107. reading: make(chan struct{}, 1),
  108. handshakeRecv: make(chan chan struct{}),
  109. closed: closer.NewCloser(),
  110. cancelHandshaker: func() {},
  111. replayProtectionWindow: uint(replayProtectionWindow),
  112. state: State{
  113. isClient: isClient,
  114. },
  115. }
  116. c.setRemoteEpoch(0)
  117. c.setLocalEpoch(0)
  118. serverName := config.ServerName
  119. // Do not allow the use of an IP address literal as an SNI value.
  120. // See RFC 6066, Section 3.
  121. if net.ParseIP(serverName) != nil {
  122. serverName = ""
  123. }
  124. curves := config.EllipticCurves
  125. if len(curves) == 0 {
  126. curves = defaultCurves
  127. }
  128. hsCfg := &handshakeConfig{
  129. localPSKCallback: config.PSK,
  130. localPSKIdentityHint: config.PSKIdentityHint,
  131. localCipherSuites: cipherSuites,
  132. localSignatureSchemes: signatureSchemes,
  133. extendedMasterSecret: config.ExtendedMasterSecret,
  134. localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
  135. serverName: serverName,
  136. supportedProtocols: config.SupportedProtocols,
  137. clientAuth: config.ClientAuth,
  138. localCertificates: config.Certificates,
  139. insecureSkipVerify: config.InsecureSkipVerify,
  140. verifyPeerCertificate: config.VerifyPeerCertificate,
  141. verifyConnection: config.VerifyConnection,
  142. rootCAs: config.RootCAs,
  143. clientCAs: config.ClientCAs,
  144. customCipherSuites: config.CustomCipherSuites,
  145. retransmitInterval: workerInterval,
  146. log: logger,
  147. initialEpoch: 0,
  148. keyLogWriter: config.KeyLogWriter,
  149. sessionStore: config.SessionStore,
  150. ellipticCurves: curves,
  151. localGetCertificate: config.GetCertificate,
  152. localGetClientCertificate: config.GetClientCertificate,
  153. insecureSkipHelloVerify: config.InsecureSkipVerifyHello,
  154. }
  155. // rfc5246#section-7.4.3
  156. // In addition, the hash and signature algorithms MUST be compatible
  157. // with the key in the server's end-entity certificate.
  158. if !isClient {
  159. cert, err := hsCfg.getCertificate(&ClientHelloInfo{})
  160. if err != nil && !errors.Is(err, errNoCertificates) {
  161. return nil, err
  162. }
  163. hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites)
  164. }
  165. var initialFlight flightVal
  166. var initialFSMState handshakeState
  167. if initialState != nil {
  168. if c.state.isClient {
  169. initialFlight = flight5
  170. } else {
  171. initialFlight = flight6
  172. }
  173. initialFSMState = handshakeFinished
  174. c.state = *initialState
  175. } else {
  176. if c.state.isClient {
  177. initialFlight = flight1
  178. } else {
  179. initialFlight = flight0
  180. }
  181. initialFSMState = handshakePreparing
  182. }
  183. // Do handshake
  184. if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
  185. return nil, err
  186. }
  187. c.log.Trace("Handshake Completed")
  188. return c, nil
  189. }
  190. // Dial connects to the given network address and establishes a DTLS connection on top.
  191. // Connection handshake will timeout using ConnectContextMaker in the Config.
  192. // If you want to specify the timeout duration, use DialWithContext() instead.
  193. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  194. ctx, cancel := config.connectContextMaker()
  195. defer cancel()
  196. return DialWithContext(ctx, network, raddr, config)
  197. }
  198. // Client establishes a DTLS connection over an existing connection.
  199. // Connection handshake will timeout using ConnectContextMaker in the Config.
  200. // If you want to specify the timeout duration, use ClientWithContext() instead.
  201. func Client(conn net.Conn, config *Config) (*Conn, error) {
  202. ctx, cancel := config.connectContextMaker()
  203. defer cancel()
  204. return ClientWithContext(ctx, conn, config)
  205. }
  206. // Server listens for incoming DTLS connections.
  207. // Connection handshake will timeout using ConnectContextMaker in the Config.
  208. // If you want to specify the timeout duration, use ServerWithContext() instead.
  209. func Server(conn net.Conn, config *Config) (*Conn, error) {
  210. ctx, cancel := config.connectContextMaker()
  211. defer cancel()
  212. return ServerWithContext(ctx, conn, config)
  213. }
  214. // DialWithContext connects to the given network address and establishes a DTLS connection on top.
  215. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) {
  216. pConn, err := net.DialUDP(network, nil, raddr)
  217. if err != nil {
  218. return nil, err
  219. }
  220. return ClientWithContext(ctx, pConn, config)
  221. }
  222. // ClientWithContext establishes a DTLS connection over an existing connection.
  223. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  224. switch {
  225. case config == nil:
  226. return nil, errNoConfigProvided
  227. case config.PSK != nil && config.PSKIdentityHint == nil:
  228. return nil, errPSKAndIdentityMustBeSetForClient
  229. }
  230. return createConn(ctx, conn, config, true, nil)
  231. }
  232. // ServerWithContext listens for incoming DTLS connections.
  233. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
  234. if config == nil {
  235. return nil, errNoConfigProvided
  236. }
  237. return createConn(ctx, conn, config, false, nil)
  238. }
  239. // Read reads data from the connection.
  240. func (c *Conn) Read(p []byte) (n int, err error) {
  241. if !c.isHandshakeCompletedSuccessfully() {
  242. return 0, errHandshakeInProgress
  243. }
  244. select {
  245. case <-c.readDeadline.Done():
  246. return 0, errDeadlineExceeded
  247. default:
  248. }
  249. for {
  250. select {
  251. case <-c.readDeadline.Done():
  252. return 0, errDeadlineExceeded
  253. case out, ok := <-c.decrypted:
  254. if !ok {
  255. return 0, io.EOF
  256. }
  257. switch val := out.(type) {
  258. case ([]byte):
  259. if len(p) < len(val) {
  260. return 0, errBufferTooSmall
  261. }
  262. copy(p, val)
  263. return len(val), nil
  264. case (error):
  265. return 0, val
  266. }
  267. }
  268. }
  269. }
  270. // Write writes len(p) bytes from p to the DTLS connection
  271. func (c *Conn) Write(p []byte) (int, error) {
  272. if c.isConnectionClosed() {
  273. return 0, ErrConnClosed
  274. }
  275. select {
  276. case <-c.writeDeadline.Done():
  277. return 0, errDeadlineExceeded
  278. default:
  279. }
  280. if !c.isHandshakeCompletedSuccessfully() {
  281. return 0, errHandshakeInProgress
  282. }
  283. return len(p), c.writePackets(c.writeDeadline, []*packet{
  284. {
  285. record: &recordlayer.RecordLayer{
  286. Header: recordlayer.Header{
  287. Epoch: c.state.getLocalEpoch(),
  288. Version: protocol.Version1_2,
  289. },
  290. Content: &protocol.ApplicationData{
  291. Data: p,
  292. },
  293. },
  294. shouldEncrypt: true,
  295. },
  296. })
  297. }
  298. // Close closes the connection.
  299. func (c *Conn) Close() error {
  300. err := c.close(true) //nolint:contextcheck
  301. c.handshakeLoopsFinished.Wait()
  302. return err
  303. }
  304. // ConnectionState returns basic DTLS details about the connection.
  305. // Note that this replaced the `Export` function of v1.
  306. func (c *Conn) ConnectionState() State {
  307. c.lock.RLock()
  308. defer c.lock.RUnlock()
  309. return *c.state.clone()
  310. }
  311. // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
  312. func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
  313. c.lock.RLock()
  314. defer c.lock.RUnlock()
  315. if c.state.srtpProtectionProfile == 0 {
  316. return 0, false
  317. }
  318. return c.state.srtpProtectionProfile, true
  319. }
  320. func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error {
  321. c.lock.Lock()
  322. defer c.lock.Unlock()
  323. var rawPackets [][]byte
  324. for _, p := range pkts {
  325. if h, ok := p.record.Content.(*handshake.Handshake); ok {
  326. handshakeRaw, err := p.record.Marshal()
  327. if err != nil {
  328. return err
  329. }
  330. c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)",
  331. srvCliStr(c.state.isClient), h.Header.Type.String(),
  332. p.record.Header.Epoch, h.Header.MessageSequence)
  333. c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient)
  334. rawHandshakePackets, err := c.processHandshakePacket(p, h)
  335. if err != nil {
  336. return err
  337. }
  338. rawPackets = append(rawPackets, rawHandshakePackets...)
  339. } else {
  340. rawPacket, err := c.processPacket(p)
  341. if err != nil {
  342. return err
  343. }
  344. rawPackets = append(rawPackets, rawPacket)
  345. }
  346. }
  347. if len(rawPackets) == 0 {
  348. return nil
  349. }
  350. compactedRawPackets := c.compactRawPackets(rawPackets)
  351. for _, compactedRawPackets := range compactedRawPackets {
  352. if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil {
  353. return netError(err)
  354. }
  355. }
  356. return nil
  357. }
  358. func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte {
  359. combinedRawPackets := make([][]byte, 0)
  360. currentCombinedRawPacket := make([]byte, 0)
  361. for _, rawPacket := range rawPackets {
  362. if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit {
  363. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  364. currentCombinedRawPacket = []byte{}
  365. }
  366. currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...)
  367. }
  368. combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket)
  369. return combinedRawPackets
  370. }
  371. func (c *Conn) processPacket(p *packet) ([]byte, error) {
  372. epoch := p.record.Header.Epoch
  373. for len(c.state.localSequenceNumber) <= int(epoch) {
  374. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  375. }
  376. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  377. if seq > recordlayer.MaxSequenceNumber {
  378. // RFC 6347 Section 4.1.0
  379. // The implementation must either abandon an association or rehandshake
  380. // prior to allowing the sequence number to wrap.
  381. return nil, errSequenceNumberOverflow
  382. }
  383. p.record.Header.SequenceNumber = seq
  384. rawPacket, err := p.record.Marshal()
  385. if err != nil {
  386. return nil, err
  387. }
  388. if p.shouldEncrypt {
  389. var err error
  390. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  391. if err != nil {
  392. return nil, err
  393. }
  394. }
  395. return rawPacket, nil
  396. }
  397. func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) {
  398. rawPackets := make([][]byte, 0)
  399. handshakeFragments, err := c.fragmentHandshake(h)
  400. if err != nil {
  401. return nil, err
  402. }
  403. epoch := p.record.Header.Epoch
  404. for len(c.state.localSequenceNumber) <= int(epoch) {
  405. c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0))
  406. }
  407. for _, handshakeFragment := range handshakeFragments {
  408. seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1
  409. if seq > recordlayer.MaxSequenceNumber {
  410. return nil, errSequenceNumberOverflow
  411. }
  412. recordlayerHeader := &recordlayer.Header{
  413. Version: p.record.Header.Version,
  414. ContentType: p.record.Header.ContentType,
  415. ContentLen: uint16(len(handshakeFragment)),
  416. Epoch: p.record.Header.Epoch,
  417. SequenceNumber: seq,
  418. }
  419. rawPacket, err := recordlayerHeader.Marshal()
  420. if err != nil {
  421. return nil, err
  422. }
  423. p.record.Header = *recordlayerHeader
  424. rawPacket = append(rawPacket, handshakeFragment...)
  425. if p.shouldEncrypt {
  426. var err error
  427. rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket)
  428. if err != nil {
  429. return nil, err
  430. }
  431. }
  432. rawPackets = append(rawPackets, rawPacket)
  433. }
  434. return rawPackets, nil
  435. }
  436. func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) {
  437. content, err := h.Message.Marshal()
  438. if err != nil {
  439. return nil, err
  440. }
  441. fragmentedHandshakes := make([][]byte, 0)
  442. contentFragments := splitBytes(content, c.maximumTransmissionUnit)
  443. if len(contentFragments) == 0 {
  444. contentFragments = [][]byte{
  445. {},
  446. }
  447. }
  448. offset := 0
  449. for _, contentFragment := range contentFragments {
  450. contentFragmentLen := len(contentFragment)
  451. headerFragment := &handshake.Header{
  452. Type: h.Header.Type,
  453. Length: h.Header.Length,
  454. MessageSequence: h.Header.MessageSequence,
  455. FragmentOffset: uint32(offset),
  456. FragmentLength: uint32(contentFragmentLen),
  457. }
  458. offset += contentFragmentLen
  459. fragmentedHandshake, err := headerFragment.Marshal()
  460. if err != nil {
  461. return nil, err
  462. }
  463. fragmentedHandshake = append(fragmentedHandshake, contentFragment...)
  464. fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake)
  465. }
  466. return fragmentedHandshakes, nil
  467. }
  468. var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals
  469. New: func() interface{} {
  470. b := make([]byte, inboundBufferSize)
  471. return &b
  472. },
  473. }
  474. func (c *Conn) readAndBuffer(ctx context.Context) error {
  475. bufptr, ok := poolReadBuffer.Get().(*[]byte)
  476. if !ok {
  477. return errFailedToAccessPoolReadBuffer
  478. }
  479. defer poolReadBuffer.Put(bufptr)
  480. b := *bufptr
  481. i, err := c.nextConn.ReadContext(ctx, b)
  482. if err != nil {
  483. return netError(err)
  484. }
  485. pkts, err := recordlayer.UnpackDatagram(b[:i])
  486. if err != nil {
  487. return err
  488. }
  489. var hasHandshake bool
  490. for _, p := range pkts {
  491. hs, alert, err := c.handleIncomingPacket(ctx, p, true)
  492. if alert != nil {
  493. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  494. if err == nil {
  495. err = alertErr
  496. }
  497. }
  498. }
  499. if hs {
  500. hasHandshake = true
  501. }
  502. var e *alertError
  503. if errors.As(err, &e) {
  504. if e.IsFatalOrCloseNotify() {
  505. return e
  506. }
  507. } else if err != nil {
  508. return e
  509. }
  510. }
  511. if hasHandshake {
  512. done := make(chan struct{})
  513. select {
  514. case c.handshakeRecv <- done:
  515. // If the other party may retransmit the flight,
  516. // we should respond even if it not a new message.
  517. <-done
  518. case <-c.fsm.Done():
  519. }
  520. }
  521. return nil
  522. }
  523. func (c *Conn) handleQueuedPackets(ctx context.Context) error {
  524. pkts := c.encryptedPackets
  525. c.encryptedPackets = nil
  526. for _, p := range pkts {
  527. _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue
  528. if alert != nil {
  529. if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil {
  530. if err == nil {
  531. err = alertErr
  532. }
  533. }
  534. }
  535. var e *alertError
  536. if errors.As(err, &e) {
  537. if e.IsFatalOrCloseNotify() {
  538. return e
  539. }
  540. } else if err != nil {
  541. return e
  542. }
  543. }
  544. return nil
  545. }
// handleIncomingPacket processes a single received record.
//
// buf is the raw record. enqueue controls whether a record that cannot be
// processed yet (next epoch, or cipher suite not initialized) is appended to
// c.encryptedPackets for a later handleQueuedPackets call.
//
// It returns whether the record carried handshake data, an alert that the
// caller should send to the peer (or nil), and an error. Records that fail
// to parse, belong to a future epoch, are replays, or fail to decrypt are
// dropped and reported as (false, nil, nil). A received alert is returned
// wrapped in *alertError.
//
// NOTE(review): queued records keep a reference to buf rather than a copy.
// readAndBuffer obtains its read buffer from a pool — confirm that
// recordlayer.UnpackDatagram hands out copies, otherwise queued records may
// be overwritten by later reads.
func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
	h := &recordlayer.Header{}
	if err := h.Unmarshal(buf); err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("discarded broken packet: %v", err)
		return false, nil, nil
	}

	// Validate epoch
	remoteEpoch := c.state.getRemoteEpoch()
	if h.Epoch > remoteEpoch {
		// More than one epoch ahead: drop. Exactly one ahead: keep for later
		// if the caller allows queuing.
		if h.Epoch > remoteEpoch+1 {
			c.log.Debugf("discarded future packet (epoch: %d, seq: %d)",
				h.Epoch, h.SequenceNumber,
			)
			return false, nil, nil
		}
		if enqueue {
			c.log.Debug("received packet of next epoch, queuing packet")
			c.encryptedPackets = append(c.encryptedPackets, buf)
		}
		return false, nil, nil
	}

	// Anti-replay protection
	// One replay detector per epoch, created on first use.
	for len(c.state.replayDetector) <= int(h.Epoch) {
		c.state.replayDetector = append(c.state.replayDetector,
			replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber),
		)
	}
	// markPacketAsValid is only called further down, once the record has
	// been accepted.
	markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber)
	if !ok {
		c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)",
			h.Epoch, h.SequenceNumber,
		)
		return false, nil, nil
	}

	// Decrypt
	// Records of epoch 0 are processed without decryption.
	if h.Epoch != 0 {
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				c.encryptedPackets = append(c.encryptedPackets, buf)
				c.log.Debug("handshake not finished, queuing packet")
			}
			return false, nil, nil
		}

		var err error
		buf, err = c.state.cipherSuite.Decrypt(buf)
		if err != nil {
			c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
			return false, nil, nil
		}
	}

	// The fragment buffer receives its own copy of the record.
	isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...))
	if err != nil {
		// Decode error must be silently discarded
		// [RFC6347 Section-4.1.2.7]
		c.log.Debugf("defragment failed: %s", err)
		return false, nil, nil
	} else if isHandshake {
		markPacketAsValid()
		// Move every message the fragment buffer can now deliver into the
		// handshake cache, recorded as sent by the peer.
		for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() {
			header := &handshake.Header{}
			if err := header.Unmarshal(out); err != nil {
				c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err)
				continue
			}
			c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient)
		}

		return true, nil, nil
	}

	r := &recordlayer.RecordLayer{}
	if err := r.Unmarshal(buf); err != nil {
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err
	}

	switch content := r.Content.(type) {
	case *alert.Alert:
		c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String())
		var a *alert.Alert
		if content.Description == alert.CloseNotify {
			// Respond with a close_notify [RFC5246 Section 7.2.1]
			a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify}
		}
		markPacketAsValid()
		return false, a, &alertError{content}
	case *protocol.ChangeCipherSpec:
		if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
			if enqueue {
				c.encryptedPackets = append(c.encryptedPackets, buf)
				c.log.Debugf("CipherSuite not initialized, queuing packet")
			}
			return false, nil, nil
		}

		newRemoteEpoch := h.Epoch + 1
		c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch)

		// Only advance the remote epoch by exactly one step.
		if c.state.getRemoteEpoch()+1 == newRemoteEpoch {
			c.setRemoteEpoch(newRemoteEpoch)
			markPacketAsValid()
		}
	case *protocol.ApplicationData:
		if h.Epoch == 0 {
			return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero
		}

		markPacketAsValid()

		// Hand the payload to Read; give up if the connection is closed or
		// ctx is done first.
		select {
		case c.decrypted <- content.Data:
		case <-c.closed.Done():
		case <-ctx.Done():
		}

	default:
		return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType())
	}
	return false, nil, nil
}
// recvHandshake returns the receive side of handshakeRecv, the channel on
// which readAndBuffer announces that handshake data has arrived. Each value
// received is a channel that readAndBuffer waits on after a successful send.
func (c *Conn) recvHandshake() <-chan chan struct{} {
	return c.handshakeRecv
}
  662. func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
  663. if level == alert.Fatal && len(c.state.SessionID) > 0 {
  664. // According to the RFC, we need to delete the stored session.
  665. // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
  666. if ss := c.fsm.cfg.sessionStore; ss != nil {
  667. c.log.Tracef("clean invalid session: %s", c.state.SessionID)
  668. if err := ss.Del(c.sessionKey()); err != nil {
  669. return err
  670. }
  671. }
  672. }
  673. return c.writePackets(ctx, []*packet{
  674. {
  675. record: &recordlayer.RecordLayer{
  676. Header: recordlayer.Header{
  677. Epoch: c.state.getLocalEpoch(),
  678. Version: protocol.Version1_2,
  679. },
  680. Content: &alert.Alert{
  681. Level: level,
  682. Description: desc,
  683. },
  684. },
  685. shouldEncrypt: c.isHandshakeCompletedSuccessfully(),
  686. },
  687. })
  688. }
  689. func (c *Conn) setHandshakeCompletedSuccessfully() {
  690. c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
  691. }
  692. func (c *Conn) isHandshakeCompletedSuccessfully() bool {
  693. boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
  694. return boolean.bool
  695. }
// handshake creates the handshake FSM and starts two goroutines: one running
// the FSM and one reading from the underlying connection. It then blocks
// until the FSM first reports handshakeFinished (returns nil), an error is
// reported by either goroutine, or ctx is done.
//
// On error or ctx expiry both goroutines are cancelled and waited for, and
// the error is passed through translateHandshakeCtxError. On success both
// goroutines keep running; they are stopped later through
// c.cancelHandshaker / c.cancelHandshakeReader.
func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit
	c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight)

	// done is closed the first time the FSM reports handshakeFinished.
	done := make(chan struct{})
	// The read loop gets its own context, independent of ctx, so it can
	// outlive this call.
	ctxRead, cancelRead := context.WithCancel(context.Background())
	c.cancelHandshakeReader = cancelRead
	cfg.onFlightState = func(f flightVal, s handshakeState) {
		if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
			c.setHandshakeCompletedSuccessfully()
			close(done)
		}
	}

	ctxHs, cancel := context.WithCancel(context.Background())
	c.cancelHandshaker = cancel

	// Buffered so that the first reporter never blocks; later errors are
	// dropped by the non-blocking sends below.
	firstErr := make(chan error, 1)

	c.handshakeLoopsFinished.Add(2)

	// Handshake routine should be live until close.
	// The other party may request retransmission of the last flight to cope with packet drop.
	go func() {
		defer c.handshakeLoopsFinished.Done()
		err := c.fsm.Run(ctxHs, c, initialState)
		// Cancellation is the normal way to stop the FSM and is not reported.
		if !errors.Is(err, context.Canceled) {
			select {
			case firstErr <- err:
			default:
			}
		}
	}()
	go func() {
		defer func() {
			// Escaping read loop.
			// It's safe to close decrypted channel now.
			close(c.decrypted)

			// Force stop handshaker when the underlying connection is closed.
			cancel()
		}()
		defer c.handshakeLoopsFinished.Done()
		for {
			if err := c.readAndBuffer(ctxRead); err != nil {
				var e *alertError
				if errors.As(err, &e) {
					if !e.IsFatalOrCloseNotify() {
						if c.isHandshakeCompletedSuccessfully() {
							// Pass the error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
						}
						continue // non-fatal alert must not stop read loop
					}
				} else {
					switch {
					// Deadline, cancellation and EOF end the read loop.
					case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF):
					default:
						if c.isHandshakeCompletedSuccessfully() {
							// Keep read loop and pass the read error to Read()
							select {
							case c.decrypted <- err:
							case <-c.closed.Done():
							case <-ctxRead.Done():
							}
							continue // other read errors must not stop the read loop after the handshake
						}
					}
				}

				// From here on the read loop terminates.
				select {
				case firstErr <- err:
				default:
				}

				if e != nil {
					if e.IsFatalOrCloseNotify() {
						_ = c.close(false) //nolint:contextcheck
					}
				}
				if !c.isConnectionClosed() && errors.Is(err, context.Canceled) {
					c.log.Trace("handshake timeouts - closing underline connection")
					_ = c.close(false) //nolint:contextcheck
				}
				return
			}
		}
	}()

	select {
	case err := <-firstErr:
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(err)
	case <-ctx.Done():
		cancelRead()
		cancel()
		c.handshakeLoopsFinished.Wait()
		return c.translateHandshakeCtxError(ctx.Err())
	case <-done:
		return nil
	}
}
  794. func (c *Conn) translateHandshakeCtxError(err error) error {
  795. if err == nil {
  796. return nil
  797. }
  798. if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() {
  799. return nil
  800. }
  801. return &HandshakeError{Err: err}
  802. }
  803. func (c *Conn) close(byUser bool) error {
  804. c.cancelHandshaker()
  805. c.cancelHandshakeReader()
  806. if c.isHandshakeCompletedSuccessfully() && byUser {
  807. // Discard error from notify() to return non-error on the first user call of Close()
  808. // even if the underlying connection is already closed.
  809. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify)
  810. }
  811. c.closeLock.Lock()
  812. // Don't return ErrConnClosed at the first time of the call from user.
  813. closedByUser := c.connectionClosedByUser
  814. if byUser {
  815. c.connectionClosedByUser = true
  816. }
  817. isClosed := c.isConnectionClosed()
  818. c.closed.Close()
  819. c.closeLock.Unlock()
  820. if closedByUser {
  821. return ErrConnClosed
  822. }
  823. if isClosed {
  824. return nil
  825. }
  826. return c.nextConn.Close()
  827. }
  828. func (c *Conn) isConnectionClosed() bool {
  829. select {
  830. case <-c.closed.Done():
  831. return true
  832. default:
  833. return false
  834. }
  835. }
// setLocalEpoch stores epoch as the connection's local epoch.
func (c *Conn) setLocalEpoch(epoch uint16) {
	c.state.localEpoch.Store(epoch)
}
// setRemoteEpoch stores epoch as the connection's remote epoch.
func (c *Conn) setRemoteEpoch(epoch uint16) {
	c.state.remoteEpoch.Store(epoch)
}
// LocalAddr implements net.Conn.LocalAddr
// It returns the local address of the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.nextConn.LocalAddr()
}
// RemoteAddr implements net.Conn.RemoteAddr
// It returns the remote address of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.nextConn.RemoteAddr()
}
  850. func (c *Conn) sessionKey() []byte {
  851. if c.state.isClient {
  852. // As ServerName can be like 0.example.com, it's better to add
  853. // delimiter character which is not allowed to be in
  854. // neither address or domain name.
  855. return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
  856. }
  857. return c.state.SessionID
  858. }
  859. // SetDeadline implements net.Conn.SetDeadline
  860. func (c *Conn) SetDeadline(t time.Time) error {
  861. c.readDeadline.Set(t)
  862. return c.SetWriteDeadline(t)
  863. }
// SetReadDeadline implements net.Conn.SetReadDeadline
// It always returns nil.
func (c *Conn) SetReadDeadline(t time.Time) error {
	c.readDeadline.Set(t)
	// Read deadline is fully managed by this layer.
	// Don't set read deadline to underlying connection.
	return nil
}
// SetWriteDeadline implements net.Conn.SetWriteDeadline
// It always returns nil.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	c.writeDeadline.Set(t)
	// Write deadline is also fully managed by this layer.
	return nil
}