socket.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. package utp
  2. import (
  3. "context"
  4. "errors"
  5. "io"
  6. "log"
  7. "math/rand"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/anacrolix/missinggo"
  12. "github.com/anacrolix/missinggo/inproc"
  13. "github.com/anacrolix/missinggo/pproffd"
  14. )
  15. var (
  16. _ net.Listener = &Socket{}
  17. _ net.PacketConn = &Socket{}
  18. )
// connKey uniquely identifies any uTP connection on top of the underlying
// packet stream: the remote address plus the connection ID that the Conn
// receives on (its recv_id; see registerConn and connForRead).
//
// Field order matters: keys are built with positional literals such as
// connKey{remoteAddr, id}.
type connKey struct {
	remoteAddr resolvedAddrStr
	connID     uint16
}
// A Socket wraps a net.PacketConn, diverting uTP packets to its child uTP
// Conns. Received packets that don't parse as uTP, and non-SYN uTP packets
// that match no Conn, are queued for ReadFrom, so the Socket can also be used
// as a net.PacketConn.
type Socket struct {
	// The underlying packet connection; closed when the Socket is destroyed.
	pc net.PacketConn
	// Registered Conns, keyed by remote address and recv_id.
	conns map[connKey]*Conn
	// Set while backlog is non-empty (see backlogChanged); nextSyn waits on it.
	backlogNotEmpty missinggo.Event
	// Inbound SYNs not yet accepted, capped at the package-level backlog size.
	backlog map[syn]struct{}
	// Set by Close and CloseNow.
	closed missinggo.Event
	// Set by destroy.
	destroyed missinggo.Event
	// Counts reads/writes in progress on pc; CloseNow waits for them.
	wgReadWrite sync.WaitGroup
	// Packets not consumed by uTP, for ReadFrom. Packets arriving while it
	// is full are dropped (see unusedRead).
	unusedReads chan read
	// Read/write deadlines consulted by ReadFrom and WriteTo.
	connDeadlines
	// If a read error occurs on the underlying net.PacketConn, it is put
	// here. This is because reading is done in its own goroutine to dispatch
	// to uTP Conns.
	ReadErr error
}
  42. func listenPacket(network, addr string) (pc net.PacketConn, err error) {
  43. if network == "inproc" {
  44. return inproc.ListenPacket(network, addr)
  45. }
  46. return net.ListenPacket(network, addr)
  47. }
  48. // NewSocket creates a net.PacketConn with the given network and address, and
  49. // returns a Socket dispatching on it.
  50. func NewSocket(network, addr string) (s *Socket, err error) {
  51. if network == "" {
  52. network = "udp"
  53. }
  54. pc, err := listenPacket(network, addr)
  55. if err != nil {
  56. return
  57. }
  58. return NewSocketFromPacketConn(pc)
  59. }
  60. // Create a Socket, using the provided net.PacketConn. If you want to retain
  61. // use of the net.PacketConn after the Socket closes it, override the
  62. // net.PacketConn's Close method, or use NetSocketFromPacketConnNoClose.
  63. func NewSocketFromPacketConn(pc net.PacketConn) (s *Socket, err error) {
  64. s = &Socket{
  65. backlog: make(map[syn]struct{}, backlog),
  66. pc: pc,
  67. unusedReads: make(chan read, 100),
  68. wgReadWrite: sync.WaitGroup{},
  69. }
  70. mu.Lock()
  71. sockets[s] = struct{}{}
  72. mu.Unlock()
  73. go s.reader()
  74. return
  75. }
  76. // Create a Socket using the provided PacketConn, that doesn't close the
  77. // PacketConn when the Socket is closed.
  78. func NewSocketFromPacketConnNoClose(pc net.PacketConn) (s *Socket, err error) {
  79. return NewSocketFromPacketConn(packetConnNopCloser{pc})
  80. }
  81. func (s *Socket) unusedRead(read read) {
  82. unusedReads.Add(1)
  83. select {
  84. case s.unusedReads <- read:
  85. default:
  86. // Drop the packet.
  87. unusedReadsDropped.Add(1)
  88. }
  89. }
  90. func (s *Socket) strNetAddr(str string) (a net.Addr) {
  91. var err error
  92. switch n := s.network(); n {
  93. case "udp":
  94. a, err = net.ResolveUDPAddr(n, str)
  95. case "inproc":
  96. a, err = inproc.ResolveAddr(n, str)
  97. default:
  98. panic(n)
  99. }
  100. if err != nil {
  101. panic(err)
  102. }
  103. return
  104. }
  105. func (s *Socket) pushBacklog(syn syn) {
  106. if _, ok := s.backlog[syn]; ok {
  107. return
  108. }
  109. // Pop a pseudo-random syn to make room. TODO: Use missinggo/orderedmap,
  110. // coz that's what is wanted here.
  111. for k := range s.backlog {
  112. if len(s.backlog) < backlog {
  113. break
  114. }
  115. delete(s.backlog, k)
  116. // A syn is sent on the remote's recv_id, so this is where we can send
  117. // the reset.
  118. s.reset(s.strNetAddr(k.addr), k.seq_nr, k.conn_id)
  119. }
  120. s.backlog[syn] = struct{}{}
  121. s.backlogChanged()
  122. }
// reader is the Socket's receive loop, started by NewSocketFromPacketConn.
// It holds mu except while blocked in pc.ReadFrom, and dispatches each packet
// through handleReceivedPacket. It returns when, after a read, the Socket is
// found destroyed, or when a read fails (the error is logged and stored in
// ReadErr). On return it destroys the Socket and then releases mu.
func (s *Socket) reader() {
	mu.Lock()
	defer mu.Unlock()
	// Deferred after the Unlock, so it runs first, while mu is still held.
	defer s.destroy()
	// Receive buffer reused for every read; each packet is copied out below.
	var b [maxRecvSize]byte
	for {
		// Count the read as in progress (CloseNow waits on wgReadWrite) and
		// drop mu for the blocking read.
		s.wgReadWrite.Add(1)
		mu.Unlock()
		n, addr, err := s.pc.ReadFrom(b[:])
		// Done before re-locking, so a CloseNow holding mu can finish waiting.
		s.wgReadWrite.Done()
		mu.Lock()
		if s.destroyed.IsSet() {
			return
		}
		if err != nil {
			log.Printf("error reading Socket PacketConn: %s", err)
			s.ReadErr = err
			return
		}
		// Copy the packet out of b, which the next read overwrites.
		s.handleReceivedPacket(read{
			append([]byte(nil), b[:n]...),
			addr,
		})
	}
}
  148. func receivedUTPPacketSize(n int) {
  149. if n > largestReceivedUTPPacket {
  150. largestReceivedUTPPacket = n
  151. largestReceivedUTPPacketExpvar.Set(int64(n))
  152. }
  153. }
  154. func (s *Socket) connForRead(h header, from net.Addr) (c *Conn, ok bool) {
  155. c, ok = s.conns[connKey{
  156. resolvedAddrStr(from.String()),
  157. func() uint16 {
  158. if h.Type == stSyn {
  159. // SYNs have a ConnID one lower than the eventual recvID, and we index
  160. // the connections with that, so use it for the lookup.
  161. return h.ConnID + 1
  162. } else {
  163. return h.ConnID
  164. }
  165. }(),
  166. }]
  167. return
  168. }
  169. func (s *Socket) handlePacketReceivedForEstablishedConn(h header, from net.Addr, data []byte, c *Conn) {
  170. if h.Type == stSyn {
  171. if h.ConnID == c.send_id-2 {
  172. // This is a SYN for connection that cannot exist locally. The
  173. // connection the remote wants to establish here with the proposed
  174. // recv_id, already has an existing connection that was dialled
  175. // *out* from this socket, which is why the send_id is 1 higher,
  176. // rather than 1 lower than the recv_id.
  177. log.Print("resetting conflicting syn")
  178. s.reset(from, h.SeqNr, h.ConnID)
  179. return
  180. } else if h.ConnID != c.send_id {
  181. panic("bad assumption")
  182. }
  183. }
  184. c.receivePacket(h, data)
  185. }
// handleReceivedPacket dispatches one packet read from the underlying
// PacketConn. The reader goroutine calls it with mu held.
//   - Packets shorter than 20 bytes, or whose header fails to unmarshal, has
//     a type above stMax, or isn't version 1, are treated as non-uTP and
//     queued for ReadFrom via unusedRead.
//   - Packets matching a registered Conn (see connForRead) go to that Conn.
//   - SYNs matching no Conn are added to the backlog for Accept.
//   - Any other packet matching no Conn is also passed to unusedRead; unless
//     it is an ST_RESET, resets are sent back first.
func (s *Socket) handleReceivedPacket(p read) {
	// NOTE(review): 20 is presumably the fixed uTP header size; confirm
	// against header.Unmarshal.
	if len(p.data) < 20 {
		s.unusedRead(p)
		return
	}
	var h header
	hEnd, err := h.Unmarshal(p.data)
	if err != nil || h.Type > stMax || h.Version != 1 {
		s.unusedRead(p)
		return
	}
	if c, ok := s.connForRead(h, p.from); ok {
		receivedUTPPacketSize(len(p.data))
		// Only the bytes after the header (from hEnd on) are passed as payload.
		s.handlePacketReceivedForEstablishedConn(h, p.from, p.data[hEnd:], c)
		return
	}
	// Packet doesn't belong to an existing connection.
	switch h.Type {
	case stSyn:
		s.pushBacklog(syn{
			seq_nr:  h.SeqNr,
			conn_id: h.ConnID,
			addr:    p.from.String(),
		})
		return
	case stReset:
		// Could be a late arriving packet for a Conn we're already done with.
		// If it was for an existing connection, we would have handled it
		// earlier. No reset is sent in reply; it is only passed on below.
	default:
		unexpectedPacketsRead.Add(1)
		// This is an unexpected packet. We'll send a reset, but also pass it
		// on. I don't think you can reset on the received packets ConnID if
		// it isn't a SYN, as the send_id will differ in this case. Resets are
		// therefore sent on the ConnID and on both of its neighbours.
		s.reset(p.from, h.SeqNr, h.ConnID)
		// Connection initiated by remote.
		s.reset(p.from, h.SeqNr, h.ConnID-1)
		// Connection initiated locally.
		s.reset(p.from, h.SeqNr, h.ConnID+1)
	}
	// Not consumed by uTP: make it available to ReadFrom.
	s.unusedRead(p)
}
  228. // Send a reset in response to a packet with the given header.
  229. func (s *Socket) reset(addr net.Addr, ackNr, connId uint16) {
  230. b := make([]byte, 0, maxHeaderSize)
  231. h := header{
  232. Type: stReset,
  233. Version: 1,
  234. ConnID: connId,
  235. AckNr: ackNr,
  236. }
  237. b = b[:h.Marshal(b)]
  238. go s.writeTo(b, addr)
  239. }
  240. // Return a recv_id that should be free. Handling the case where it isn't is
  241. // deferred to a more appropriate function.
  242. func (s *Socket) newConnID(remoteAddr resolvedAddrStr) (id uint16) {
  243. // Rather than use math.Rand, which requires generating all the IDs up
  244. // front and allocating a slice, we do it on the stack, generating the IDs
  245. // only as required. To do this, we use the fact that the array is
  246. // default-initialized. IDs that are 0, are actually their index in the
  247. // array. IDs that are non-zero, are +1 from their intended ID.
  248. var idsBack [0x10000]int
  249. ids := idsBack[:]
  250. for len(ids) != 0 {
  251. // Pick the next ID from the untried ids.
  252. i := rand.Intn(len(ids))
  253. id = uint16(ids[i])
  254. // If it's zero, then treat it as though the index i was the ID.
  255. // Otherwise the value we get is the ID+1.
  256. if id == 0 {
  257. id = uint16(i)
  258. } else {
  259. id--
  260. }
  261. // Check there's no connection using this ID for its recv_id...
  262. _, ok1 := s.conns[connKey{remoteAddr, id}]
  263. // and if we're connecting to our own Socket, that there isn't a Conn
  264. // already receiving on what will correspond to our send_id. Note that
  265. // we just assume that we could be connecting to our own Socket. This
  266. // will halve the available connection IDs to each distinct remote
  267. // address. Presumably that's ~0x8000, down from ~0x10000.
  268. _, ok2 := s.conns[connKey{remoteAddr, id + 1}]
  269. _, ok4 := s.conns[connKey{remoteAddr, id - 1}]
  270. if !ok1 && !ok2 && !ok4 {
  271. return
  272. }
  273. // The set of possible IDs is shrinking. The highest one will be lost, so
  274. // it's moved to the location of the one we just tried.
  275. ids[i] = len(ids) // Conveniently already +1.
  276. // And shrink.
  277. ids = ids[:len(ids)-1]
  278. }
  279. return
  280. }
  281. var (
  282. zeroipv4 = net.ParseIP("0.0.0.0")
  283. zeroipv6 = net.ParseIP("::")
  284. ipv4lo = mustResolveUDP("127.0.0.1")
  285. ipv6lo = mustResolveUDP("::1")
  286. )
  287. func mustResolveUDP(addr string) net.IP {
  288. u, err := net.ResolveIPAddr("ip", addr)
  289. if err != nil {
  290. panic(err)
  291. }
  292. return u.IP
  293. }
  294. func realRemoteAddr(addr net.Addr) net.Addr {
  295. udpAddr, ok := addr.(*net.UDPAddr)
  296. if ok {
  297. if udpAddr.IP.Equal(zeroipv4) {
  298. udpAddr.IP = ipv4lo
  299. }
  300. if udpAddr.IP.Equal(zeroipv6) {
  301. udpAddr.IP = ipv6lo
  302. }
  303. }
  304. return addr
  305. }
  306. func (s *Socket) newConn(addr net.Addr) (c *Conn) {
  307. addr = realRemoteAddr(addr)
  308. c = &Conn{
  309. socket: s,
  310. remoteSocketAddr: addr,
  311. created: time.Now(),
  312. }
  313. c.sendPendingSendSendStateTimer = missinggo.StoppedFuncTimer(c.sendPendingSendStateTimerCallback)
  314. c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback)
  315. return
  316. }
  317. func (s *Socket) Dial(addr string) (net.Conn, error) {
  318. return s.DialContext(context.Background(), "", addr)
  319. }
  320. func (s *Socket) resolveAddr(network, addr string) (net.Addr, error) {
  321. n := s.network()
  322. if network != "" {
  323. n = network
  324. }
  325. if n == "inproc" {
  326. return inproc.ResolveAddr(n, addr)
  327. }
  328. return net.ResolveUDPAddr(n, addr)
  329. }
  330. func (s *Socket) network() string {
  331. return s.pc.LocalAddr().Network()
  332. }
  333. func (s *Socket) startOutboundConn(addr net.Addr) (c *Conn, err error) {
  334. mu.Lock()
  335. defer mu.Unlock()
  336. c = s.newConn(addr)
  337. c.recv_id = s.newConnID(resolvedAddrStr(c.RemoteAddr().String()))
  338. c.send_id = c.recv_id + 1
  339. if logLevel >= 1 {
  340. log.Printf("dial registering addr: %s", c.RemoteAddr().String())
  341. }
  342. if !s.registerConn(c.recv_id, resolvedAddrStr(c.RemoteAddr().String()), c) {
  343. err = errors.New("couldn't register new connection")
  344. log.Println(c.recv_id, c.RemoteAddr().String())
  345. for k, c := range s.conns {
  346. log.Println(k, c, c.age())
  347. }
  348. log.Printf("that's %d connections", len(s.conns))
  349. }
  350. if err != nil {
  351. return
  352. }
  353. c.seq_nr = 1
  354. c.writeSyn()
  355. return
  356. }
  357. func (s *Socket) DialContext(ctx context.Context, network, addr string) (nc net.Conn, err error) {
  358. netAddr, err := s.resolveAddr(network, addr)
  359. if err != nil {
  360. return
  361. }
  362. c, err := s.startOutboundConn(netAddr)
  363. if err != nil {
  364. return
  365. }
  366. connErr := make(chan error, 1)
  367. go func() {
  368. connErr <- c.recvSynAck()
  369. }()
  370. select {
  371. case err = <-connErr:
  372. case <-ctx.Done():
  373. err = ctx.Err()
  374. }
  375. if err != nil {
  376. mu.Lock()
  377. c.destroy(errors.New("dial timeout"))
  378. mu.Unlock()
  379. return
  380. }
  381. mu.Lock()
  382. c.updateCanWrite()
  383. mu.Unlock()
  384. nc = pproffd.WrapNetConn(c)
  385. return
  386. }
  387. func (me *Socket) writeTo(b []byte, addr net.Addr) (n int, err error) {
  388. apdc := artificialPacketDropChance
  389. if apdc != 0 {
  390. if rand.Float64() < apdc {
  391. n = len(b)
  392. return
  393. }
  394. }
  395. n, err = me.pc.WriteTo(b, addr)
  396. return
  397. }
  398. // Returns true if the connection was newly registered, false otherwise.
  399. func (s *Socket) registerConn(recvID uint16, remoteAddr resolvedAddrStr, c *Conn) bool {
  400. if s.conns == nil {
  401. s.conns = make(map[connKey]*Conn)
  402. }
  403. key := connKey{remoteAddr, recvID}
  404. if _, ok := s.conns[key]; ok {
  405. return false
  406. }
  407. c.connKey = key
  408. s.conns[key] = c
  409. return true
  410. }
  411. func (s *Socket) backlogChanged() {
  412. if len(s.backlog) != 0 {
  413. s.backlogNotEmpty.Set()
  414. } else {
  415. s.backlogNotEmpty.Clear()
  416. }
  417. }
  418. func (s *Socket) nextSyn() (syn syn, err error) {
  419. for {
  420. missinggo.WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed)
  421. if s.closed.IsSet() {
  422. err = errClosed
  423. return
  424. }
  425. if s.destroyed.IsSet() {
  426. err = s.ReadErr
  427. return
  428. }
  429. for k := range s.backlog {
  430. syn = k
  431. delete(s.backlog, k)
  432. s.backlogChanged()
  433. return
  434. }
  435. }
  436. }
  437. // ACK a SYN, and return a new Conn for it. ok is false if the SYN is bad, and
  438. // the Conn invalid.
  439. func (s *Socket) ackSyn(syn syn) (c *Conn, ok bool) {
  440. c = s.newConn(s.strNetAddr(syn.addr))
  441. c.send_id = syn.conn_id
  442. c.recv_id = c.send_id + 1
  443. c.seq_nr = uint16(rand.Int())
  444. c.lastAck = c.seq_nr - 1
  445. c.ack_nr = syn.seq_nr
  446. c.synAcked = true
  447. c.updateCanWrite()
  448. if !s.registerConn(c.recv_id, resolvedAddrStr(syn.addr), c) {
  449. // SYN that triggered this accept duplicates existing connection.
  450. // Ack again in case the SYN was a resend.
  451. c = s.conns[connKey{resolvedAddrStr(syn.addr), c.recv_id}]
  452. if c.send_id != syn.conn_id {
  453. panic(":|")
  454. }
  455. c.sendState()
  456. return
  457. }
  458. c.sendState()
  459. ok = true
  460. return
  461. }
  462. // Accept and return a new uTP connection.
  463. func (s *Socket) Accept() (net.Conn, error) {
  464. mu.Lock()
  465. defer mu.Unlock()
  466. for {
  467. syn, err := s.nextSyn()
  468. if err != nil {
  469. return nil, err
  470. }
  471. c, ok := s.ackSyn(syn)
  472. if ok {
  473. c.updateCanWrite()
  474. return c, nil
  475. }
  476. }
  477. }
  478. // The address we're listening on for new uTP connections.
  479. func (s *Socket) Addr() net.Addr {
  480. return s.pc.LocalAddr()
  481. }
  482. func (s *Socket) CloseNow() error {
  483. mu.Lock()
  484. defer mu.Unlock()
  485. s.closed.Set()
  486. for _, c := range s.conns {
  487. c.closeNow()
  488. }
  489. s.destroy()
  490. s.wgReadWrite.Wait()
  491. return nil
  492. }
  493. func (s *Socket) Close() error {
  494. mu.Lock()
  495. defer mu.Unlock()
  496. s.closed.Set()
  497. s.lazyDestroy()
  498. return nil
  499. }
  500. func (s *Socket) lazyDestroy() {
  501. if len(s.conns) != 0 {
  502. return
  503. }
  504. if !s.closed.IsSet() {
  505. return
  506. }
  507. s.destroy()
  508. }
  509. func (s *Socket) destroy() {
  510. delete(sockets, s)
  511. s.destroyed.Set()
  512. s.pc.Close()
  513. for _, c := range s.conns {
  514. c.destroy(errors.New("Socket destroyed"))
  515. }
  516. }
  517. func (s *Socket) LocalAddr() net.Addr {
  518. return s.pc.LocalAddr()
  519. }
  520. func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  521. select {
  522. case read, ok := <-s.unusedReads:
  523. if !ok {
  524. err = io.EOF
  525. return
  526. }
  527. n = copy(p, read.data)
  528. addr = read.from
  529. return
  530. case <-s.connDeadlines.read.passed.LockedChan(&mu):
  531. err = errTimeout
  532. return
  533. }
  534. }
  535. func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) {
  536. mu.Lock()
  537. if s.connDeadlines.write.passed.IsSet() {
  538. err = errTimeout
  539. }
  540. s.wgReadWrite.Add(1)
  541. defer s.wgReadWrite.Done()
  542. mu.Unlock()
  543. if err != nil {
  544. return
  545. }
  546. return s.pc.WriteTo(b, addr)
  547. }