client.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package udp
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "fmt"
  7. "github.com/protolambda/ctxlock"
  8. "io"
  9. "net"
  10. "time"
  11. "github.com/anacrolix/dht/v2/krpc"
  12. )
  13. // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
  14. // connection specifics.
  15. type Client struct {
  16. mu ctxlock.Lock
  17. connId ConnectionId
  18. connIdIssued time.Time
  19. shouldReconnectOverride func() bool
  20. Dispatcher *Dispatcher
  21. Writer io.Writer
  22. }
  23. func (cl *Client) Announce(
  24. ctx context.Context, req AnnounceRequest, opts Options,
  25. // Decides whether the response body is IPv6 or IPv4, see BEP 15.
  26. ipv6 func(net.Addr) bool,
  27. ) (
  28. respHdr AnnounceResponseHeader,
  29. // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
  30. peers AnnounceResponsePeers,
  31. err error,
  32. ) {
  33. respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
  34. if err != nil {
  35. return
  36. }
  37. r := bytes.NewBuffer(respBody)
  38. err = Read(r, &respHdr)
  39. if err != nil {
  40. err = fmt.Errorf("reading response header: %w", err)
  41. return
  42. }
  43. if ipv6(addr) {
  44. peers = &krpc.CompactIPv6NodeAddrs{}
  45. } else {
  46. peers = &krpc.CompactIPv4NodeAddrs{}
  47. }
  48. err = peers.UnmarshalBinary(r.Bytes())
  49. if err != nil {
  50. err = fmt.Errorf("reading response peers: %w", err)
  51. }
  52. return
  53. }
  54. // There's no way to pass options in a scrape, since we don't when the request body ends.
  55. func (cl *Client) Scrape(
  56. ctx context.Context, ihs []InfoHash,
  57. ) (
  58. out ScrapeResponse, err error,
  59. ) {
  60. respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
  61. if err != nil {
  62. return
  63. }
  64. r := bytes.NewBuffer(respBody)
  65. for r.Len() != 0 {
  66. var item ScrapeInfohashResult
  67. err = Read(r, &item)
  68. if err != nil {
  69. return
  70. }
  71. out = append(out, item)
  72. }
  73. if len(out) > len(ihs) {
  74. err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
  75. return
  76. }
  77. return
  78. }
  79. func (cl *Client) shouldReconnectDefault() bool {
  80. return cl.connIdIssued.IsZero() || time.Since(cl.connIdIssued) >= time.Minute
  81. }
  82. func (cl *Client) shouldReconnect() bool {
  83. if cl.shouldReconnectOverride != nil {
  84. return cl.shouldReconnectOverride()
  85. }
  86. return cl.shouldReconnectDefault()
  87. }
  88. func (cl *Client) connect(ctx context.Context) (err error) {
  89. if !cl.shouldReconnect() {
  90. return nil
  91. }
  92. return cl.doConnectRoundTrip(ctx)
  93. }
  94. // This just does the connect request and updates local state if it succeeds.
  95. func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) {
  96. respBody, _, err := cl.request(ctx, ActionConnect, nil)
  97. if err != nil {
  98. return err
  99. }
  100. var connResp ConnectionResponse
  101. err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
  102. if err != nil {
  103. return
  104. }
  105. cl.connId = connResp.ConnectionId
  106. cl.connIdIssued = time.Now()
  107. //log.Printf("conn id set to %x", cl.connId)
  108. return
  109. }
  110. func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
  111. if action == ActionConnect {
  112. id = ConnectRequestConnectionId
  113. return
  114. }
  115. err = cl.connect(ctx)
  116. if err != nil {
  117. return
  118. }
  119. id = cl.connId
  120. return
  121. }
  122. func (cl *Client) writeRequest(
  123. ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer,
  124. ) (
  125. err error,
  126. ) {
  127. var connId ConnectionId
  128. if action == ActionConnect {
  129. connId = ConnectRequestConnectionId
  130. } else {
  131. // We lock here while establishing a connection ID, and then ensuring that the request is
  132. // written before allowing the connection ID to change again. This is to ensure the server
  133. // doesn't assign us another ID before we've sent this request. Note that this doesn't allow
  134. // for us to return if the context is cancelled while we wait to obtain a new ID.
  135. err = cl.mu.LockCtx(ctx)
  136. if err != nil {
  137. return fmt.Errorf("locking connection id: %w", err)
  138. }
  139. defer cl.mu.Unlock()
  140. connId, err = cl.connIdForRequest(ctx, action)
  141. if err != nil {
  142. return
  143. }
  144. }
  145. buf.Reset()
  146. err = Write(buf, RequestHeader{
  147. ConnectionId: connId,
  148. Action: action,
  149. TransactionId: tId,
  150. })
  151. if err != nil {
  152. panic(err)
  153. }
  154. buf.Write(body)
  155. _, err = cl.Writer.Write(buf.Bytes())
  156. //log.Printf("sent request with conn id %x", connId)
  157. return
  158. }
  159. func (cl *Client) requestWriter(
  160. ctx context.Context,
  161. action Action,
  162. body []byte,
  163. tId TransactionId,
  164. ) (err error) {
  165. var buf bytes.Buffer
  166. for n := 0; ; n++ {
  167. err = cl.writeRequest(ctx, action, body, tId, &buf)
  168. if err != nil {
  169. return
  170. }
  171. select {
  172. case <-ctx.Done():
  173. return ctx.Err()
  174. case <-time.After(timeout(n)):
  175. }
  176. }
  177. }
  const (
	// ConnectionIdMissmatchNul is the exact error text, including its misspelling and trailing
	// NUL byte, that some trackers send when a request carries a connection ID they don't
	// accept.
	ConnectionIdMissmatchNul = "Connection ID missmatch.\x00"
)
  // ErrorResponse carries the body of a tracker reply whose action signals an error instead of
// the requested action.
type ErrorResponse struct {
	// Message is the tracker-supplied text, verbatim; it may contain control bytes such as NUL.
	Message string
}

// Error implements the error interface. Message is quoted (backquoted when possible) so that
// control characters remain visible.
func (er ErrorResponse) Error() string {
	formatted := fmt.Sprintf("error response: %#q", er.Message)
	return formatted
}
  185. func (cl *Client) request(
  186. ctx context.Context,
  187. action Action,
  188. body []byte,
  189. ) (respBody []byte, addr net.Addr, err error) {
  190. respChan := make(chan DispatchedResponse, 1)
  191. t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
  192. respChan <- dr
  193. })
  194. defer t.End()
  195. ctx, cancel := context.WithCancel(ctx)
  196. defer cancel()
  197. writeErr := make(chan error, 1)
  198. go func() {
  199. writeErr <- cl.requestWriter(ctx, action, body, t.Id())
  200. }()
  201. select {
  202. case dr := <-respChan:
  203. if dr.Header.Action == action {
  204. respBody = dr.Body
  205. addr = dr.Addr
  206. } else if dr.Header.Action == ActionError {
  207. // udp://tracker.torrent.eu.org:451/announce frequently returns "Connection ID
  208. // missmatch.\x00"
  209. err = ErrorResponse{Message: string(dr.Body)}
  210. // Force a reconnection. Probably any error is worth doing this for, but the one we're
  211. // specifically interested in is ConnectionIdMissmatchNul.
  212. cl.connIdIssued = time.Time{}
  213. } else {
  214. err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
  215. }
  216. case err = <-writeErr:
  217. err = fmt.Errorf("write error: %w", err)
  218. case <-ctx.Done():
  219. err = context.Cause(ctx)
  220. }
  221. return
  222. }