client.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. package stun
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "log"
  7. "net"
  8. "runtime"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. )
  13. // Dial connects to the address on the named network and then
  14. // initializes Client on that connection, returning error if any.
  15. func Dial(network, address string) (*Client, error) {
  16. conn, err := net.Dial(network, address)
  17. if err != nil {
  18. return nil, err
  19. }
  20. return NewClient(conn)
  21. }
  var (
  	// ErrNoConnection is returned by NewClient when, after all options
  	// have been applied, the client has no connection (conn is nil).
  	ErrNoConnection = errors.New("no connection provided")
  )
  24. // ClientOption sets some client option.
  25. type ClientOption func(c *Client)
  26. // WithHandler sets client handler which is called if Agent emits the Event
  27. // with TransactionID that is not currently registered by Client.
  28. // Useful for handling Data indications from TURN server.
  29. func WithHandler(h Handler) ClientOption {
  30. return func(c *Client) {
  31. c.handler = h
  32. }
  33. }
  34. // WithRTO sets client RTO as defined in STUN RFC.
  35. func WithRTO(rto time.Duration) ClientOption {
  36. return func(c *Client) {
  37. c.rto = int64(rto)
  38. }
  39. }
  40. // WithClock sets Clock of client, the source of current time.
  41. // Also clock is passed to default collector if set.
  42. func WithClock(clock Clock) ClientOption {
  43. return func(c *Client) {
  44. c.clock = clock
  45. }
  46. }
  47. // WithTimeoutRate sets RTO timer minimum resolution.
  48. func WithTimeoutRate(d time.Duration) ClientOption {
  49. return func(c *Client) {
  50. c.rtoRate = d
  51. }
  52. }
  53. // WithAgent sets client STUN agent.
  54. //
  55. // Defaults to agent implementation in current package,
  56. // see agent.go.
  57. func WithAgent(a ClientAgent) ClientOption {
  58. return func(c *Client) {
  59. c.a = a
  60. }
  61. }
  62. // WithCollector rests client timeout collector, the implementation
  63. // of ticker which calls function on each tick.
  64. func WithCollector(coll Collector) ClientOption {
  65. return func(c *Client) {
  66. c.collector = coll
  67. }
  68. }
  69. // WithNoConnClose prevents client from closing underlying connection when
  70. // the Close() method is called.
  71. var WithNoConnClose ClientOption = func(c *Client) {
  72. c.closeConn = false
  73. }
  74. // WithNoRetransmit disables retransmissions and sets RTO to
  75. // defaultMaxAttempts * defaultRTO which will be effectively time out
  76. // if not set.
  77. //
  78. // Useful for TCP connections where transport handles RTO.
  79. func WithNoRetransmit(c *Client) {
  80. c.maxAttempts = 0
  81. if c.rto == 0 {
  82. c.rto = defaultMaxAttempts * int64(defaultRTO)
  83. }
  84. }
  // Client defaults used by NewClient unless overridden by options.
  const (
  	defaultRTO         = 300 * time.Millisecond // initial retransmission timeout
  	defaultMaxAttempts = 7                      // maximum number of retransmissions
  	defaultTimeoutRate = 5 * time.Millisecond   // collector tick interval
  )
  90. // NewClient initializes new Client from provided options,
  91. // starting internal goroutines and using default options fields
  92. // if necessary. Call Close method after using Client to close conn and
  93. // release resources.
  94. //
  95. // The conn will be closed on Close call. Use WithNoConnClose option to
  96. // prevent that.
  97. //
  98. // Note that user should handle the protocol multiplexing, client does not
  99. // provide any API for it, so if you need to read application data, wrap the
  100. // connection with your (de-)multiplexer and pass the wrapper as conn.
  101. func NewClient(conn Connection, options ...ClientOption) (*Client, error) {
  102. c := &Client{
  103. close: make(chan struct{}),
  104. c: conn,
  105. clock: systemClock,
  106. rto: int64(defaultRTO),
  107. rtoRate: defaultTimeoutRate,
  108. t: make(map[transactionID]*clientTransaction, 100),
  109. maxAttempts: defaultMaxAttempts,
  110. closeConn: true,
  111. }
  112. for _, o := range options {
  113. o(c)
  114. }
  115. if c.c == nil {
  116. return nil, ErrNoConnection
  117. }
  118. if c.a == nil {
  119. c.a = NewAgent(nil)
  120. }
  121. if err := c.a.SetHandler(c.handleAgentCallback); err != nil {
  122. return nil, err
  123. }
  124. if c.collector == nil {
  125. c.collector = &tickerCollector{
  126. close: make(chan struct{}),
  127. clock: c.clock,
  128. }
  129. }
  130. if err := c.collector.Start(c.rtoRate, func(t time.Time) {
  131. closedOrPanic(c.a.Collect(t))
  132. }); err != nil {
  133. return nil, err
  134. }
  135. c.wg.Add(1)
  136. go c.readUntilClosed()
  137. runtime.SetFinalizer(c, clientFinalizer)
  138. return c, nil
  139. }
  140. func clientFinalizer(c *Client) {
  141. if c == nil {
  142. return
  143. }
  144. err := c.Close()
  145. if err == ErrClientClosed {
  146. return
  147. }
  148. if err == nil {
  149. log.Println("client: called finalizer on non-closed client") // nolint
  150. return
  151. }
  152. log.Println("client: called finalizer on non-closed client:", err) // nolint
  153. }
  // Connection is the transport a Client reads from and writes to. Close
  // closes it unless WithNoConnClose is used. It has exactly the method set
  // of io.Reader, io.Writer and io.Closer combined.
  type Connection interface {
  	io.ReadWriteCloser
  }
  160. // ClientAgent is Agent implementation that is used by Client to
  161. // process transactions.
  162. type ClientAgent interface {
  163. Process(*Message) error
  164. Close() error
  165. Start(id [TransactionIDSize]byte, deadline time.Time) error
  166. Stop(id [TransactionIDSize]byte) error
  167. Collect(time.Time) error
  168. SetHandler(h Handler) error
  169. }
  170. // Client simulates "connection" to STUN server.
  171. type Client struct {
  172. rto int64 // time.Duration
  173. a ClientAgent
  174. c Connection
  175. close chan struct{}
  176. rtoRate time.Duration
  177. maxAttempts int32
  178. closed bool
  179. closeConn bool // should call c.Close() while closing
  180. wg sync.WaitGroup
  181. clock Clock
  182. handler Handler
  183. collector Collector
  184. t map[transactionID]*clientTransaction
  185. // mux guards closed and t
  186. mux sync.RWMutex
  187. }
  188. // clientTransaction represents transaction in progress.
  189. // If transaction is succeed or failed, f will be called
  190. // provided by event.
  191. // Concurrent access is invalid.
  192. type clientTransaction struct {
  193. id transactionID
  194. attempt int32
  195. calls int32
  196. h Handler
  197. start time.Time
  198. rto time.Duration
  199. raw []byte
  200. }
  201. func (t *clientTransaction) handle(e Event) {
  202. if atomic.AddInt32(&t.calls, 1) == 1 {
  203. t.h(e)
  204. }
  205. }
  206. var clientTransactionPool = &sync.Pool{
  207. New: func() interface{} {
  208. return &clientTransaction{
  209. raw: make([]byte, 1500),
  210. }
  211. },
  212. }
  213. func acquireClientTransaction() *clientTransaction {
  214. return clientTransactionPool.Get().(*clientTransaction)
  215. }
  216. func putClientTransaction(t *clientTransaction) {
  217. t.raw = t.raw[:0]
  218. t.start = time.Time{}
  219. t.attempt = 0
  220. t.id = transactionID{}
  221. clientTransactionPool.Put(t)
  222. }
  223. func (t *clientTransaction) nextTimeout(now time.Time) time.Time {
  224. return now.Add(time.Duration(t.attempt+1) * t.rto)
  225. }
  226. // start registers transaction.
  227. //
  228. // Could return ErrClientClosed, ErrTransactionExists.
  229. func (c *Client) start(t *clientTransaction) error {
  230. c.mux.Lock()
  231. defer c.mux.Unlock()
  232. if c.closed {
  233. return ErrClientClosed
  234. }
  235. _, exists := c.t[t.id]
  236. if exists {
  237. return ErrTransactionExists
  238. }
  239. c.t[t.id] = t
  240. return nil
  241. }
  // Clock is the source of current time used by the client. Tests can
  // substitute it through WithClock.
  type Clock interface {
  	Now() time.Time
  }

  // systemClockService is the Clock backed by the real wall clock.
  type systemClockService struct{}

  // Now returns time.Now().
  func (systemClockService) Now() time.Time {
  	return time.Now()
  }

  // systemClock is the default Clock installed by NewClient.
  var systemClock systemClockService
  249. // SetRTO sets current RTO value.
  250. func (c *Client) SetRTO(rto time.Duration) {
  251. atomic.StoreInt64(&c.rto, int64(rto))
  252. }
  // sprintErr renders err for an error message, printing nil as "<nil>".
  func sprintErr(err error) string {
  	if err != nil {
  		return err.Error()
  	}
  	return "<nil>"
  }

  // StopErr is returned when Client fails to stop an agent transaction
  // while it is handling another error.
  type StopErr struct {
  	Err   error // value returned by Stop()
  	Cause error // error that caused Stop() call
  }

  // Error implements the error interface.
  func (e StopErr) Error() string {
  	cause, stop := sprintErr(e.Cause), sprintErr(e.Err)
  	return fmt.Sprintf("error while stopping due to %s: %s", cause, stop)
  }

  // CloseErr reports failures of Client.Close to close the agent and/or
  // the connection.
  type CloseErr struct {
  	AgentErr      error
  	ConnectionErr error
  }

  // Error implements the error interface.
  func (c CloseErr) Error() string {
  	conn, agent := sprintErr(c.ConnectionErr), sprintErr(c.AgentErr)
  	return fmt.Sprintf("failed to close: %s (connection), %s (agent)", conn, agent)
  }
  276. func (c *Client) readUntilClosed() {
  277. defer c.wg.Done()
  278. m := new(Message)
  279. m.Raw = make([]byte, 1024)
  280. for {
  281. select {
  282. case <-c.close:
  283. return
  284. default:
  285. }
  286. _, err := m.ReadFrom(c.c)
  287. if err == nil {
  288. if pErr := c.a.Process(m); pErr == ErrAgentClosed {
  289. return
  290. }
  291. }
  292. }
  293. }
  294. func closedOrPanic(err error) {
  295. if err == nil || err == ErrAgentClosed {
  296. return
  297. }
  298. panic(err) // nolint
  299. }
  300. type tickerCollector struct {
  301. close chan struct{}
  302. wg sync.WaitGroup
  303. clock Clock
  304. }
  305. // Collector calls function f with constant rate.
  306. //
  307. // The simple Collector is ticker which calls function on each tick.
  308. type Collector interface {
  309. Start(rate time.Duration, f func(now time.Time)) error
  310. Close() error
  311. }
  312. func (a *tickerCollector) Start(rate time.Duration, f func(now time.Time)) error {
  313. t := time.NewTicker(rate)
  314. a.wg.Add(1)
  315. go func() {
  316. defer a.wg.Done()
  317. for {
  318. select {
  319. case <-a.close:
  320. t.Stop()
  321. return
  322. case <-t.C:
  323. f(a.clock.Now())
  324. }
  325. }
  326. }()
  327. return nil
  328. }
  329. func (a *tickerCollector) Close() error {
  330. close(a.close)
  331. a.wg.Wait()
  332. return nil
  333. }
  334. // ErrClientClosed indicates that client is closed.
  335. var ErrClientClosed = errors.New("client is closed")
  336. // Close stops internal connection and agent, returning CloseErr on error.
  337. func (c *Client) Close() error {
  338. if err := c.checkInit(); err != nil {
  339. return err
  340. }
  341. c.mux.Lock()
  342. if c.closed {
  343. c.mux.Unlock()
  344. return ErrClientClosed
  345. }
  346. c.closed = true
  347. c.mux.Unlock()
  348. if closeErr := c.collector.Close(); closeErr != nil {
  349. return closeErr
  350. }
  351. var connErr error
  352. agentErr := c.a.Close()
  353. if c.closeConn {
  354. connErr = c.c.Close()
  355. }
  356. close(c.close)
  357. c.wg.Wait()
  358. if agentErr == nil && connErr == nil {
  359. return nil
  360. }
  361. return CloseErr{
  362. AgentErr: agentErr,
  363. ConnectionErr: connErr,
  364. }
  365. }
  366. // Indicate sends indication m to server. Shorthand to Start call
  367. // with zero deadline and callback.
  368. func (c *Client) Indicate(m *Message) error {
  369. return c.Start(m, nil)
  370. }
  371. // callbackWaitHandler blocks on wait() call until callback is called.
  372. type callbackWaitHandler struct {
  373. handler Handler
  374. callback func(event Event)
  375. cond *sync.Cond
  376. processed bool
  377. }
  378. func (s *callbackWaitHandler) HandleEvent(e Event) {
  379. s.cond.L.Lock()
  380. if s.callback == nil {
  381. panic("s.callback is nil") // nolint
  382. }
  383. s.callback(e)
  384. s.processed = true
  385. s.cond.Broadcast()
  386. s.cond.L.Unlock()
  387. }
  388. func (s *callbackWaitHandler) wait() {
  389. s.cond.L.Lock()
  390. for !s.processed {
  391. s.cond.Wait()
  392. }
  393. s.processed = false
  394. s.callback = nil
  395. s.cond.L.Unlock()
  396. }
  397. func (s *callbackWaitHandler) setCallback(f func(event Event)) {
  398. if f == nil {
  399. panic("f is nil") // nolint
  400. }
  401. s.cond.L.Lock()
  402. s.callback = f
  403. if s.handler == nil {
  404. s.handler = s.HandleEvent
  405. }
  406. s.cond.L.Unlock()
  407. }
  408. var callbackWaitHandlerPool = sync.Pool{
  409. New: func() interface{} {
  410. return &callbackWaitHandler{
  411. cond: sync.NewCond(new(sync.Mutex)),
  412. }
  413. },
  414. }
  415. // ErrClientNotInitialized means that client connection or agent is nil.
  416. var ErrClientNotInitialized = errors.New("client not initialized")
  417. func (c *Client) checkInit() error {
  418. if c == nil || c.c == nil || c.a == nil || c.close == nil {
  419. return ErrClientNotInitialized
  420. }
  421. return nil
  422. }
  423. // Do is Start wrapper that waits until callback is called. If no callback
  424. // provided, Indicate is called instead.
  425. //
  426. // Do has cpu overhead due to blocking, see BenchmarkClient_Do.
  427. // Use Start method for less overhead.
  428. func (c *Client) Do(m *Message, f func(Event)) error {
  429. if err := c.checkInit(); err != nil {
  430. return err
  431. }
  432. if f == nil {
  433. return c.Indicate(m)
  434. }
  435. h := callbackWaitHandlerPool.Get().(*callbackWaitHandler)
  436. h.setCallback(f)
  437. defer func() {
  438. callbackWaitHandlerPool.Put(h)
  439. }()
  440. if err := c.Start(m, h.handler); err != nil {
  441. return err
  442. }
  443. h.wait()
  444. return nil
  445. }
  446. func (c *Client) delete(id transactionID) {
  447. c.mux.Lock()
  448. if c.t != nil {
  449. delete(c.t, id)
  450. }
  451. c.mux.Unlock()
  452. }
  // buffer is a reusable byte slice holder for bufferPool.
  type buffer struct {
  	buf []byte
  }

  // bufferPool hands out scratch buffers, freshly allocated at 2048 bytes.
  var bufferPool = &sync.Pool{
  	New: func() interface{} {
  		b := new(buffer)
  		b.buf = make([]byte, 2048)
  		return b
  	},
  }
  461. func (c *Client) handleAgentCallback(e Event) {
  462. c.mux.Lock()
  463. if c.closed {
  464. c.mux.Unlock()
  465. return
  466. }
  467. t, found := c.t[e.TransactionID]
  468. if found {
  469. delete(c.t, t.id)
  470. }
  471. c.mux.Unlock()
  472. if !found {
  473. if c.handler != nil && e.Error != ErrTransactionStopped {
  474. c.handler(e)
  475. }
  476. // Ignoring.
  477. return
  478. }
  479. if atomic.LoadInt32(&c.maxAttempts) <= t.attempt || e.Error == nil {
  480. // Transaction completed.
  481. t.handle(e)
  482. putClientTransaction(t)
  483. return
  484. }
  485. // Doing re-transmission.
  486. t.attempt++
  487. b := bufferPool.Get().(*buffer)
  488. b.buf = b.buf[:copy(b.buf[:cap(b.buf)], t.raw)]
  489. defer bufferPool.Put(b)
  490. var (
  491. now = c.clock.Now()
  492. timeOut = t.nextTimeout(now)
  493. id = t.id
  494. )
  495. // Starting client transaction.
  496. if startErr := c.start(t); startErr != nil {
  497. c.delete(id)
  498. e.Error = startErr
  499. t.handle(e)
  500. putClientTransaction(t)
  501. return
  502. }
  503. // Starting agent transaction.
  504. if startErr := c.a.Start(id, timeOut); startErr != nil {
  505. c.delete(id)
  506. e.Error = startErr
  507. t.handle(e)
  508. putClientTransaction(t)
  509. return
  510. }
  511. // Writing message to connection again.
  512. _, writeErr := c.c.Write(b.buf)
  513. if writeErr != nil {
  514. c.delete(id)
  515. e.Error = writeErr
  516. // Stopping agent transaction instead of waiting until it's deadline.
  517. // This will call handleAgentCallback with "ErrTransactionStopped" error
  518. // which will be ignored.
  519. if stopErr := c.a.Stop(id); stopErr != nil {
  520. // Failed to stop agent transaction. Wrapping the error in StopError.
  521. e.Error = StopErr{
  522. Err: stopErr,
  523. Cause: writeErr,
  524. }
  525. }
  526. t.handle(e)
  527. putClientTransaction(t)
  528. return
  529. }
  530. }
  531. // Start starts transaction (if h set) and writes message to server, handler
  532. // is called asynchronously.
  533. func (c *Client) Start(m *Message, h Handler) error {
  534. if err := c.checkInit(); err != nil {
  535. return err
  536. }
  537. c.mux.RLock()
  538. closed := c.closed
  539. c.mux.RUnlock()
  540. if closed {
  541. return ErrClientClosed
  542. }
  543. if h != nil {
  544. // Starting transaction only if h is set. Useful for indications.
  545. t := acquireClientTransaction()
  546. t.id = m.TransactionID
  547. t.start = c.clock.Now()
  548. t.h = h
  549. t.rto = time.Duration(atomic.LoadInt64(&c.rto))
  550. t.attempt = 0
  551. t.raw = append(t.raw[:0], m.Raw...)
  552. t.calls = 0
  553. d := t.nextTimeout(t.start)
  554. if err := c.start(t); err != nil {
  555. return err
  556. }
  557. if err := c.a.Start(m.TransactionID, d); err != nil {
  558. return err
  559. }
  560. }
  561. _, err := m.WriteTo(c.c)
  562. if err != nil && h != nil {
  563. c.delete(m.TransactionID)
  564. // Stopping transaction instead of waiting until deadline.
  565. if stopErr := c.a.Stop(m.TransactionID); stopErr != nil {
  566. return StopErr{
  567. Err: stopErr,
  568. Cause: err,
  569. }
  570. }
  571. }
  572. return err
  573. }