dtlstransport.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. //go:build !js
  2. // +build !js
  3. package webrtc
  4. import (
  5. "crypto/ecdsa"
  6. "crypto/elliptic"
  7. "crypto/rand"
  8. "crypto/tls"
  9. "crypto/x509"
  10. "errors"
  11. "fmt"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/pion/dtls/v2"
  17. "github.com/pion/dtls/v2/pkg/crypto/fingerprint"
  18. "github.com/pion/interceptor"
  19. "github.com/pion/logging"
  20. "github.com/pion/rtcp"
  21. "github.com/pion/srtp/v2"
  22. "github.com/pion/webrtc/v3/internal/mux"
  23. "github.com/pion/webrtc/v3/internal/util"
  24. "github.com/pion/webrtc/v3/pkg/rtcerr"
  25. )
  26. // DTLSTransport allows an application access to information about the DTLS
  27. // transport over which RTP and RTCP packets are sent and received by
  28. // RTPSender and RTPReceiver, as well other data such as SCTP packets sent
  29. // and received by data channels.
  30. type DTLSTransport struct {
  31. lock sync.RWMutex
  32. iceTransport *ICETransport
  33. certificates []Certificate
  34. remoteParameters DTLSParameters
  35. remoteCertificate []byte
  36. state DTLSTransportState
  37. srtpProtectionProfile srtp.ProtectionProfile
  38. onStateChangeHandler func(DTLSTransportState)
  39. conn *dtls.Conn
  40. srtpSession, srtcpSession atomic.Value
  41. srtpEndpoint, srtcpEndpoint *mux.Endpoint
  42. simulcastStreams []*srtp.ReadStreamSRTP
  43. srtpReady chan struct{}
  44. dtlsMatcher mux.MatchFunc
  45. api *API
  46. log logging.LeveledLogger
  47. }
  48. // NewDTLSTransport creates a new DTLSTransport.
  49. // This constructor is part of the ORTC API. It is not
  50. // meant to be used together with the basic WebRTC API.
  51. func (api *API) NewDTLSTransport(transport *ICETransport, certificates []Certificate) (*DTLSTransport, error) {
  52. t := &DTLSTransport{
  53. iceTransport: transport,
  54. api: api,
  55. state: DTLSTransportStateNew,
  56. dtlsMatcher: mux.MatchDTLS,
  57. srtpReady: make(chan struct{}),
  58. log: api.settingEngine.LoggerFactory.NewLogger("DTLSTransport"),
  59. }
  60. if len(certificates) > 0 {
  61. now := time.Now()
  62. for _, x509Cert := range certificates {
  63. if !x509Cert.Expires().IsZero() && now.After(x509Cert.Expires()) {
  64. return nil, &rtcerr.InvalidAccessError{Err: ErrCertificateExpired}
  65. }
  66. t.certificates = append(t.certificates, x509Cert)
  67. }
  68. } else {
  69. sk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
  70. if err != nil {
  71. return nil, &rtcerr.UnknownError{Err: err}
  72. }
  73. certificate, err := GenerateCertificate(sk)
  74. if err != nil {
  75. return nil, err
  76. }
  77. t.certificates = []Certificate{*certificate}
  78. }
  79. return t, nil
  80. }
  81. // ICETransport returns the currently-configured *ICETransport or nil
  82. // if one has not been configured
  83. func (t *DTLSTransport) ICETransport() *ICETransport {
  84. t.lock.RLock()
  85. defer t.lock.RUnlock()
  86. return t.iceTransport
  87. }
  88. // onStateChange requires the caller holds the lock
  89. func (t *DTLSTransport) onStateChange(state DTLSTransportState) {
  90. t.state = state
  91. handler := t.onStateChangeHandler
  92. if handler != nil {
  93. handler(state)
  94. }
  95. }
  96. // OnStateChange sets a handler that is fired when the DTLS
  97. // connection state changes.
  98. func (t *DTLSTransport) OnStateChange(f func(DTLSTransportState)) {
  99. t.lock.Lock()
  100. defer t.lock.Unlock()
  101. t.onStateChangeHandler = f
  102. }
  103. // State returns the current dtls transport state.
  104. func (t *DTLSTransport) State() DTLSTransportState {
  105. t.lock.RLock()
  106. defer t.lock.RUnlock()
  107. return t.state
  108. }
  109. // WriteRTCP sends a user provided RTCP packet to the connected peer. If no peer is connected the
  110. // packet is discarded.
  111. func (t *DTLSTransport) WriteRTCP(pkts []rtcp.Packet) (int, error) {
  112. raw, err := rtcp.Marshal(pkts)
  113. if err != nil {
  114. return 0, err
  115. }
  116. srtcpSession, err := t.getSRTCPSession()
  117. if err != nil {
  118. return 0, err
  119. }
  120. writeStream, err := srtcpSession.OpenWriteStream()
  121. if err != nil {
  122. return 0, fmt.Errorf("%w: %v", errPeerConnWriteRTCPOpenWriteStream, err)
  123. }
  124. if n, err := writeStream.Write(raw); err != nil {
  125. return n, err
  126. }
  127. return 0, nil
  128. }
  129. // GetLocalParameters returns the DTLS parameters of the local DTLSTransport upon construction.
  130. func (t *DTLSTransport) GetLocalParameters() (DTLSParameters, error) {
  131. fingerprints := []DTLSFingerprint{}
  132. for _, c := range t.certificates {
  133. prints, err := c.GetFingerprints()
  134. if err != nil {
  135. return DTLSParameters{}, err
  136. }
  137. fingerprints = append(fingerprints, prints...)
  138. }
  139. return DTLSParameters{
  140. Role: DTLSRoleAuto, // always returns the default role
  141. Fingerprints: fingerprints,
  142. }, nil
  143. }
  144. // GetRemoteCertificate returns the certificate chain in use by the remote side
  145. // returns an empty list prior to selection of the remote certificate
  146. func (t *DTLSTransport) GetRemoteCertificate() []byte {
  147. t.lock.RLock()
  148. defer t.lock.RUnlock()
  149. return t.remoteCertificate
  150. }
  151. func (t *DTLSTransport) startSRTP() error {
  152. srtpConfig := &srtp.Config{
  153. Profile: t.srtpProtectionProfile,
  154. BufferFactory: t.api.settingEngine.BufferFactory,
  155. LoggerFactory: t.api.settingEngine.LoggerFactory,
  156. }
  157. if t.api.settingEngine.replayProtection.SRTP != nil {
  158. srtpConfig.RemoteOptions = append(
  159. srtpConfig.RemoteOptions,
  160. srtp.SRTPReplayProtection(*t.api.settingEngine.replayProtection.SRTP),
  161. )
  162. }
  163. if t.api.settingEngine.disableSRTPReplayProtection {
  164. srtpConfig.RemoteOptions = append(
  165. srtpConfig.RemoteOptions,
  166. srtp.SRTPNoReplayProtection(),
  167. )
  168. }
  169. if t.api.settingEngine.replayProtection.SRTCP != nil {
  170. srtpConfig.RemoteOptions = append(
  171. srtpConfig.RemoteOptions,
  172. srtp.SRTCPReplayProtection(*t.api.settingEngine.replayProtection.SRTCP),
  173. )
  174. }
  175. if t.api.settingEngine.disableSRTCPReplayProtection {
  176. srtpConfig.RemoteOptions = append(
  177. srtpConfig.RemoteOptions,
  178. srtp.SRTCPNoReplayProtection(),
  179. )
  180. }
  181. connState := t.conn.ConnectionState()
  182. err := srtpConfig.ExtractSessionKeysFromDTLS(&connState, t.role() == DTLSRoleClient)
  183. if err != nil {
  184. return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
  185. }
  186. srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
  187. if err != nil {
  188. return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
  189. }
  190. srtcpSession, err := srtp.NewSessionSRTCP(t.srtcpEndpoint, srtpConfig)
  191. if err != nil {
  192. return fmt.Errorf("%w: %v", errFailedToStartSRTCP, err)
  193. }
  194. t.srtpSession.Store(srtpSession)
  195. t.srtcpSession.Store(srtcpSession)
  196. close(t.srtpReady)
  197. return nil
  198. }
  199. func (t *DTLSTransport) getSRTPSession() (*srtp.SessionSRTP, error) {
  200. if value, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
  201. return value, nil
  202. }
  203. return nil, errDtlsTransportNotStarted
  204. }
  205. func (t *DTLSTransport) getSRTCPSession() (*srtp.SessionSRTCP, error) {
  206. if value, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
  207. return value, nil
  208. }
  209. return nil, errDtlsTransportNotStarted
  210. }
  211. func (t *DTLSTransport) role() DTLSRole {
  212. // If remote has an explicit role use the inverse
  213. switch t.remoteParameters.Role {
  214. case DTLSRoleClient:
  215. return DTLSRoleServer
  216. case DTLSRoleServer:
  217. return DTLSRoleClient
  218. default:
  219. }
  220. // If SettingEngine has an explicit role
  221. switch t.api.settingEngine.answeringDTLSRole {
  222. case DTLSRoleServer:
  223. return DTLSRoleServer
  224. case DTLSRoleClient:
  225. return DTLSRoleClient
  226. default:
  227. }
  228. // Remote was auto and no explicit role was configured via SettingEngine
  229. if t.iceTransport.Role() == ICERoleControlling {
  230. return DTLSRoleServer
  231. }
  232. return defaultDtlsRoleAnswer
  233. }
  234. // Start DTLS transport negotiation with the parameters of the remote DTLS transport
  235. func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
  236. // Take lock and prepare connection, we must not hold the lock
  237. // when connecting
  238. prepareTransport := func() (DTLSRole, *dtls.Config, error) {
  239. t.lock.Lock()
  240. defer t.lock.Unlock()
  241. if err := t.ensureICEConn(); err != nil {
  242. return DTLSRole(0), nil, err
  243. }
  244. if t.state != DTLSTransportStateNew {
  245. return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
  246. }
  247. t.srtpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTP)
  248. t.srtcpEndpoint = t.iceTransport.newEndpoint(mux.MatchSRTCP)
  249. t.remoteParameters = remoteParameters
  250. cert := t.certificates[0]
  251. t.onStateChange(DTLSTransportStateConnecting)
  252. return t.role(), &dtls.Config{
  253. Certificates: []tls.Certificate{
  254. {
  255. Certificate: [][]byte{cert.x509Cert.Raw},
  256. PrivateKey: cert.privateKey,
  257. },
  258. },
  259. SRTPProtectionProfiles: func() []dtls.SRTPProtectionProfile {
  260. if len(t.api.settingEngine.srtpProtectionProfiles) > 0 {
  261. return t.api.settingEngine.srtpProtectionProfiles
  262. }
  263. return defaultSrtpProtectionProfiles()
  264. }(),
  265. ClientAuth: dtls.RequireAnyClientCert,
  266. LoggerFactory: t.api.settingEngine.LoggerFactory,
  267. InsecureSkipVerify: true,
  268. }, nil
  269. }
  270. var dtlsConn *dtls.Conn
  271. dtlsEndpoint := t.iceTransport.newEndpoint(mux.MatchDTLS)
  272. role, dtlsConfig, err := prepareTransport()
  273. if err != nil {
  274. return err
  275. }
  276. if t.api.settingEngine.replayProtection.DTLS != nil {
  277. dtlsConfig.ReplayProtectionWindow = int(*t.api.settingEngine.replayProtection.DTLS)
  278. }
  279. if t.api.settingEngine.dtls.retransmissionInterval != 0 {
  280. dtlsConfig.FlightInterval = t.api.settingEngine.dtls.retransmissionInterval
  281. }
  282. // Connect as DTLS Client/Server, function is blocking and we
  283. // must not hold the DTLSTransport lock
  284. if role == DTLSRoleClient {
  285. dtlsConn, err = dtls.Client(dtlsEndpoint, dtlsConfig)
  286. } else {
  287. dtlsConn, err = dtls.Server(dtlsEndpoint, dtlsConfig)
  288. }
  289. // Re-take the lock, nothing beyond here is blocking
  290. t.lock.Lock()
  291. defer t.lock.Unlock()
  292. if err != nil {
  293. t.onStateChange(DTLSTransportStateFailed)
  294. return err
  295. }
  296. srtpProfile, ok := dtlsConn.SelectedSRTPProtectionProfile()
  297. if !ok {
  298. t.onStateChange(DTLSTransportStateFailed)
  299. return ErrNoSRTPProtectionProfile
  300. }
  301. switch srtpProfile {
  302. case dtls.SRTP_AEAD_AES_128_GCM:
  303. t.srtpProtectionProfile = srtp.ProtectionProfileAeadAes128Gcm
  304. case dtls.SRTP_AES128_CM_HMAC_SHA1_80:
  305. t.srtpProtectionProfile = srtp.ProtectionProfileAes128CmHmacSha1_80
  306. default:
  307. t.onStateChange(DTLSTransportStateFailed)
  308. return ErrNoSRTPProtectionProfile
  309. }
  310. // Check the fingerprint if a certificate was exchanged
  311. remoteCerts := dtlsConn.ConnectionState().PeerCertificates
  312. if len(remoteCerts) == 0 {
  313. t.onStateChange(DTLSTransportStateFailed)
  314. return errNoRemoteCertificate
  315. }
  316. t.remoteCertificate = remoteCerts[0]
  317. if !t.api.settingEngine.disableCertificateFingerprintVerification {
  318. parsedRemoteCert, err := x509.ParseCertificate(t.remoteCertificate)
  319. if err != nil {
  320. if closeErr := dtlsConn.Close(); closeErr != nil {
  321. t.log.Error(err.Error())
  322. }
  323. t.onStateChange(DTLSTransportStateFailed)
  324. return err
  325. }
  326. if err = t.validateFingerPrint(parsedRemoteCert); err != nil {
  327. if closeErr := dtlsConn.Close(); closeErr != nil {
  328. t.log.Error(err.Error())
  329. }
  330. t.onStateChange(DTLSTransportStateFailed)
  331. return err
  332. }
  333. }
  334. t.conn = dtlsConn
  335. t.onStateChange(DTLSTransportStateConnected)
  336. return t.startSRTP()
  337. }
  338. // Stop stops and closes the DTLSTransport object.
  339. func (t *DTLSTransport) Stop() error {
  340. t.lock.Lock()
  341. defer t.lock.Unlock()
  342. // Try closing everything and collect the errors
  343. var closeErrs []error
  344. if srtpSession, err := t.getSRTPSession(); err == nil && srtpSession != nil {
  345. closeErrs = append(closeErrs, srtpSession.Close())
  346. }
  347. if srtcpSession, err := t.getSRTCPSession(); err == nil && srtcpSession != nil {
  348. closeErrs = append(closeErrs, srtcpSession.Close())
  349. }
  350. for i := range t.simulcastStreams {
  351. closeErrs = append(closeErrs, t.simulcastStreams[i].Close())
  352. }
  353. if t.conn != nil {
  354. // dtls connection may be closed on sctp close.
  355. if err := t.conn.Close(); err != nil && !errors.Is(err, dtls.ErrConnClosed) {
  356. closeErrs = append(closeErrs, err)
  357. }
  358. }
  359. t.onStateChange(DTLSTransportStateClosed)
  360. return util.FlattenErrs(closeErrs)
  361. }
  362. func (t *DTLSTransport) validateFingerPrint(remoteCert *x509.Certificate) error {
  363. for _, fp := range t.remoteParameters.Fingerprints {
  364. hashAlgo, err := fingerprint.HashFromString(fp.Algorithm)
  365. if err != nil {
  366. return err
  367. }
  368. remoteValue, err := fingerprint.Fingerprint(remoteCert, hashAlgo)
  369. if err != nil {
  370. return err
  371. }
  372. if strings.EqualFold(remoteValue, fp.Value) {
  373. return nil
  374. }
  375. }
  376. return errNoMatchingCertificateFingerprint
  377. }
  378. func (t *DTLSTransport) ensureICEConn() error {
  379. if t.iceTransport == nil {
  380. return errICEConnectionNotStarted
  381. }
  382. return nil
  383. }
  384. func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {
  385. t.lock.Lock()
  386. defer t.lock.Unlock()
  387. t.simulcastStreams = append(t.simulcastStreams, s)
  388. }
  389. func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
  390. srtpSession, err := t.getSRTPSession()
  391. if err != nil {
  392. return nil, nil, nil, nil, err
  393. }
  394. rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
  395. if err != nil {
  396. return nil, nil, nil, nil, err
  397. }
  398. rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
  399. n, err = rtpReadStream.Read(in)
  400. return n, a, err
  401. }))
  402. srtcpSession, err := t.getSRTCPSession()
  403. if err != nil {
  404. return nil, nil, nil, nil, err
  405. }
  406. rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
  407. if err != nil {
  408. return nil, nil, nil, nil, err
  409. }
  410. rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
  411. n, err = rtcpReadStream.Read(in)
  412. return n, a, err
  413. }))
  414. return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
  415. }