| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597 |
- package utp
- import (
- "context"
- "errors"
- "io"
- "log"
- "math/rand"
- "net"
- "sync"
- "time"
- "github.com/anacrolix/missinggo"
- "github.com/anacrolix/missinggo/inproc"
- "github.com/anacrolix/missinggo/pproffd"
- )
- var (
- _ net.Listener = &Socket{}
- _ net.PacketConn = &Socket{}
- )
// Uniquely identifies any uTP connection on top of the underlying packet
// stream. It is the key type of Socket.conns.
type connKey struct {
	// String form of the remote peer's address.
	remoteAddr resolvedAddrStr
	// The connection ID the Conn is registered under; registerConn is called
	// with the Conn's recv_id.
	connID uint16
}
// A Socket wraps a net.PacketConn, diverting uTP packets to its child uTP
// Conns.
type Socket struct {
	// The underlying packet connection all reads and writes go through.
	pc net.PacketConn
	// Conns registered on this Socket. Allocated lazily by registerConn.
	conns map[connKey]*Conn
	// Set while backlog is non-empty, cleared otherwise (see backlogChanged).
	backlogNotEmpty missinggo.Event
	// SYNs received for unknown connections that Accept has not consumed yet.
	backlog map[syn]struct{}
	// Set by Close and CloseNow.
	closed missinggo.Event
	// Set by destroy.
	destroyed missinggo.Event
	// Tracks calls into pc that are in flight in reader and WriteTo; CloseNow
	// waits on it.
	wgReadWrite sync.WaitGroup
	// Packets that were not handled as uTP traffic, handed out by ReadFrom.
	unusedReads chan read
	connDeadlines
	// If a read error occurs on the underlying net.PacketConn, it is put
	// here. This is because reading is done in its own goroutine to dispatch
	// to uTP Conns.
	ReadErr error
}
- func listenPacket(network, addr string) (pc net.PacketConn, err error) {
- if network == "inproc" {
- return inproc.ListenPacket(network, addr)
- }
- return net.ListenPacket(network, addr)
- }
- // NewSocket creates a net.PacketConn with the given network and address, and
- // returns a Socket dispatching on it.
- func NewSocket(network, addr string) (s *Socket, err error) {
- if network == "" {
- network = "udp"
- }
- pc, err := listenPacket(network, addr)
- if err != nil {
- return
- }
- return NewSocketFromPacketConn(pc)
- }
- // Create a Socket, using the provided net.PacketConn. If you want to retain
- // use of the net.PacketConn after the Socket closes it, override the
- // net.PacketConn's Close method, or use NetSocketFromPacketConnNoClose.
- func NewSocketFromPacketConn(pc net.PacketConn) (s *Socket, err error) {
- s = &Socket{
- backlog: make(map[syn]struct{}, backlog),
- pc: pc,
- unusedReads: make(chan read, 100),
- wgReadWrite: sync.WaitGroup{},
- }
- mu.Lock()
- sockets[s] = struct{}{}
- mu.Unlock()
- go s.reader()
- return
- }
- // Create a Socket using the provided PacketConn, that doesn't close the
- // PacketConn when the Socket is closed.
- func NewSocketFromPacketConnNoClose(pc net.PacketConn) (s *Socket, err error) {
- return NewSocketFromPacketConn(packetConnNopCloser{pc})
- }
- func (s *Socket) unusedRead(read read) {
- unusedReads.Add(1)
- select {
- case s.unusedReads <- read:
- default:
- // Drop the packet.
- unusedReadsDropped.Add(1)
- }
- }
- func (s *Socket) strNetAddr(str string) (a net.Addr) {
- var err error
- switch n := s.network(); n {
- case "udp":
- a, err = net.ResolveUDPAddr(n, str)
- case "inproc":
- a, err = inproc.ResolveAddr(n, str)
- default:
- panic(n)
- }
- if err != nil {
- panic(err)
- }
- return
- }
- func (s *Socket) pushBacklog(syn syn) {
- if _, ok := s.backlog[syn]; ok {
- return
- }
- // Pop a pseudo-random syn to make room. TODO: Use missinggo/orderedmap,
- // coz that's what is wanted here.
- for k := range s.backlog {
- if len(s.backlog) < backlog {
- break
- }
- delete(s.backlog, k)
- // A syn is sent on the remote's recv_id, so this is where we can send
- // the reset.
- s.reset(s.strNetAddr(k.addr), k.seq_nr, k.conn_id)
- }
- s.backlog[syn] = struct{}{}
- s.backlogChanged()
- }
- func (s *Socket) reader() {
- mu.Lock()
- defer mu.Unlock()
- defer s.destroy()
- var b [maxRecvSize]byte
- for {
- s.wgReadWrite.Add(1)
- mu.Unlock()
- n, addr, err := s.pc.ReadFrom(b[:])
- s.wgReadWrite.Done()
- mu.Lock()
- if s.destroyed.IsSet() {
- return
- }
- if err != nil {
- log.Printf("error reading Socket PacketConn: %s", err)
- s.ReadErr = err
- return
- }
- s.handleReceivedPacket(read{
- append([]byte(nil), b[:n]...),
- addr,
- })
- }
- }
- func receivedUTPPacketSize(n int) {
- if n > largestReceivedUTPPacket {
- largestReceivedUTPPacket = n
- largestReceivedUTPPacketExpvar.Set(int64(n))
- }
- }
- func (s *Socket) connForRead(h header, from net.Addr) (c *Conn, ok bool) {
- c, ok = s.conns[connKey{
- resolvedAddrStr(from.String()),
- func() uint16 {
- if h.Type == stSyn {
- // SYNs have a ConnID one lower than the eventual recvID, and we index
- // the connections with that, so use it for the lookup.
- return h.ConnID + 1
- } else {
- return h.ConnID
- }
- }(),
- }]
- return
- }
- func (s *Socket) handlePacketReceivedForEstablishedConn(h header, from net.Addr, data []byte, c *Conn) {
- if h.Type == stSyn {
- if h.ConnID == c.send_id-2 {
- // This is a SYN for connection that cannot exist locally. The
- // connection the remote wants to establish here with the proposed
- // recv_id, already has an existing connection that was dialled
- // *out* from this socket, which is why the send_id is 1 higher,
- // rather than 1 lower than the recv_id.
- log.Print("resetting conflicting syn")
- s.reset(from, h.SeqNr, h.ConnID)
- return
- } else if h.ConnID != c.send_id {
- panic("bad assumption")
- }
- }
- c.receivePacket(h, data)
- }
- func (s *Socket) handleReceivedPacket(p read) {
- if len(p.data) < 20 {
- s.unusedRead(p)
- return
- }
- var h header
- hEnd, err := h.Unmarshal(p.data)
- if err != nil || h.Type > stMax || h.Version != 1 {
- s.unusedRead(p)
- return
- }
- if c, ok := s.connForRead(h, p.from); ok {
- receivedUTPPacketSize(len(p.data))
- s.handlePacketReceivedForEstablishedConn(h, p.from, p.data[hEnd:], c)
- return
- }
- // Packet doesn't belong to an existing connection.
- switch h.Type {
- case stSyn:
- s.pushBacklog(syn{
- seq_nr: h.SeqNr,
- conn_id: h.ConnID,
- addr: p.from.String(),
- })
- return
- case stReset:
- // Could be a late arriving packet for a Conn we're already done with.
- // If it was for an existing connection, we would have handled it
- // earlier.
- default:
- unexpectedPacketsRead.Add(1)
- // This is an unexpected packet. We'll send a reset, but also pass it
- // on. I don't think you can reset on the received packets ConnID if
- // it isn't a SYN, as the send_id will differ in this case.
- s.reset(p.from, h.SeqNr, h.ConnID)
- // Connection initiated by remote.
- s.reset(p.from, h.SeqNr, h.ConnID-1)
- // Connection initiated locally.
- s.reset(p.from, h.SeqNr, h.ConnID+1)
- }
- s.unusedRead(p)
- }
- // Send a reset in response to a packet with the given header.
- func (s *Socket) reset(addr net.Addr, ackNr, connId uint16) {
- b := make([]byte, 0, maxHeaderSize)
- h := header{
- Type: stReset,
- Version: 1,
- ConnID: connId,
- AckNr: ackNr,
- }
- b = b[:h.Marshal(b)]
- go s.writeTo(b, addr)
- }
- // Return a recv_id that should be free. Handling the case where it isn't is
- // deferred to a more appropriate function.
- func (s *Socket) newConnID(remoteAddr resolvedAddrStr) (id uint16) {
- // Rather than use math.Rand, which requires generating all the IDs up
- // front and allocating a slice, we do it on the stack, generating the IDs
- // only as required. To do this, we use the fact that the array is
- // default-initialized. IDs that are 0, are actually their index in the
- // array. IDs that are non-zero, are +1 from their intended ID.
- var idsBack [0x10000]int
- ids := idsBack[:]
- for len(ids) != 0 {
- // Pick the next ID from the untried ids.
- i := rand.Intn(len(ids))
- id = uint16(ids[i])
- // If it's zero, then treat it as though the index i was the ID.
- // Otherwise the value we get is the ID+1.
- if id == 0 {
- id = uint16(i)
- } else {
- id--
- }
- // Check there's no connection using this ID for its recv_id...
- _, ok1 := s.conns[connKey{remoteAddr, id}]
- // and if we're connecting to our own Socket, that there isn't a Conn
- // already receiving on what will correspond to our send_id. Note that
- // we just assume that we could be connecting to our own Socket. This
- // will halve the available connection IDs to each distinct remote
- // address. Presumably that's ~0x8000, down from ~0x10000.
- _, ok2 := s.conns[connKey{remoteAddr, id + 1}]
- _, ok4 := s.conns[connKey{remoteAddr, id - 1}]
- if !ok1 && !ok2 && !ok4 {
- return
- }
- // The set of possible IDs is shrinking. The highest one will be lost, so
- // it's moved to the location of the one we just tried.
- ids[i] = len(ids) // Conveniently already +1.
- // And shrink.
- ids = ids[:len(ids)-1]
- }
- return
- }
// Unspecified ("any") addresses and the loopback addresses that
// realRemoteAddr substitutes for them.
var (
	zeroipv4 = net.ParseIP("0.0.0.0")
	zeroipv6 = net.ParseIP("::")
	ipv4lo   = mustResolveUDP("127.0.0.1")
	ipv6lo   = mustResolveUDP("::1")
)

// mustResolveUDP resolves addr as an IP address and returns the IP. It
// panics if resolution fails, and is only meant for initializing the
// package-level addresses above.
func mustResolveUDP(addr string) net.IP {
	u, err := net.ResolveIPAddr("ip", addr)
	if err != nil {
		panic(err)
	}
	return u.IP
}

// realRemoteAddr maps a UDP address whose IP is the unspecified address
// (0.0.0.0 or ::) to the loopback address of the same family, keeping port
// and zone. Any other address is returned as is. The argument is never
// modified: when a substitution is needed, a modified copy is returned, so
// callers that keep using their own *net.UDPAddr do not see it change.
func realRemoteAddr(addr net.Addr) net.Addr {
	udpAddr, ok := addr.(*net.UDPAddr)
	if !ok || udpAddr == nil {
		return addr
	}
	var loopback net.IP
	switch {
	case udpAddr.IP.Equal(zeroipv4):
		loopback = ipv4lo
	case udpAddr.IP.Equal(zeroipv6):
		loopback = ipv6lo
	default:
		return addr
	}
	fixed := *udpAddr
	fixed.IP = loopback
	return &fixed
}
- func (s *Socket) newConn(addr net.Addr) (c *Conn) {
- addr = realRemoteAddr(addr)
- c = &Conn{
- socket: s,
- remoteSocketAddr: addr,
- created: time.Now(),
- }
- c.sendPendingSendSendStateTimer = missinggo.StoppedFuncTimer(c.sendPendingSendStateTimerCallback)
- c.packetReadTimeoutTimer = time.AfterFunc(packetReadTimeout, c.receivePacketTimeoutCallback)
- return
- }
- func (s *Socket) Dial(addr string) (net.Conn, error) {
- return s.DialContext(context.Background(), "", addr)
- }
- func (s *Socket) resolveAddr(network, addr string) (net.Addr, error) {
- n := s.network()
- if network != "" {
- n = network
- }
- if n == "inproc" {
- return inproc.ResolveAddr(n, addr)
- }
- return net.ResolveUDPAddr(n, addr)
- }
- func (s *Socket) network() string {
- return s.pc.LocalAddr().Network()
- }
- func (s *Socket) startOutboundConn(addr net.Addr) (c *Conn, err error) {
- mu.Lock()
- defer mu.Unlock()
- c = s.newConn(addr)
- c.recv_id = s.newConnID(resolvedAddrStr(c.RemoteAddr().String()))
- c.send_id = c.recv_id + 1
- if logLevel >= 1 {
- log.Printf("dial registering addr: %s", c.RemoteAddr().String())
- }
- if !s.registerConn(c.recv_id, resolvedAddrStr(c.RemoteAddr().String()), c) {
- err = errors.New("couldn't register new connection")
- log.Println(c.recv_id, c.RemoteAddr().String())
- for k, c := range s.conns {
- log.Println(k, c, c.age())
- }
- log.Printf("that's %d connections", len(s.conns))
- }
- if err != nil {
- return
- }
- c.seq_nr = 1
- c.writeSyn()
- return
- }
- func (s *Socket) DialContext(ctx context.Context, network, addr string) (nc net.Conn, err error) {
- netAddr, err := s.resolveAddr(network, addr)
- if err != nil {
- return
- }
- c, err := s.startOutboundConn(netAddr)
- if err != nil {
- return
- }
- connErr := make(chan error, 1)
- go func() {
- connErr <- c.recvSynAck()
- }()
- select {
- case err = <-connErr:
- case <-ctx.Done():
- err = ctx.Err()
- }
- if err != nil {
- mu.Lock()
- c.destroy(errors.New("dial timeout"))
- mu.Unlock()
- return
- }
- mu.Lock()
- c.updateCanWrite()
- mu.Unlock()
- nc = pproffd.WrapNetConn(c)
- return
- }
- func (me *Socket) writeTo(b []byte, addr net.Addr) (n int, err error) {
- apdc := artificialPacketDropChance
- if apdc != 0 {
- if rand.Float64() < apdc {
- n = len(b)
- return
- }
- }
- n, err = me.pc.WriteTo(b, addr)
- return
- }
- // Returns true if the connection was newly registered, false otherwise.
- func (s *Socket) registerConn(recvID uint16, remoteAddr resolvedAddrStr, c *Conn) bool {
- if s.conns == nil {
- s.conns = make(map[connKey]*Conn)
- }
- key := connKey{remoteAddr, recvID}
- if _, ok := s.conns[key]; ok {
- return false
- }
- c.connKey = key
- s.conns[key] = c
- return true
- }
- func (s *Socket) backlogChanged() {
- if len(s.backlog) != 0 {
- s.backlogNotEmpty.Set()
- } else {
- s.backlogNotEmpty.Clear()
- }
- }
- func (s *Socket) nextSyn() (syn syn, err error) {
- for {
- missinggo.WaitEvents(&mu, &s.closed, &s.backlogNotEmpty, &s.destroyed)
- if s.closed.IsSet() {
- err = errClosed
- return
- }
- if s.destroyed.IsSet() {
- err = s.ReadErr
- return
- }
- for k := range s.backlog {
- syn = k
- delete(s.backlog, k)
- s.backlogChanged()
- return
- }
- }
- }
- // ACK a SYN, and return a new Conn for it. ok is false if the SYN is bad, and
- // the Conn invalid.
- func (s *Socket) ackSyn(syn syn) (c *Conn, ok bool) {
- c = s.newConn(s.strNetAddr(syn.addr))
- c.send_id = syn.conn_id
- c.recv_id = c.send_id + 1
- c.seq_nr = uint16(rand.Int())
- c.lastAck = c.seq_nr - 1
- c.ack_nr = syn.seq_nr
- c.synAcked = true
- c.updateCanWrite()
- if !s.registerConn(c.recv_id, resolvedAddrStr(syn.addr), c) {
- // SYN that triggered this accept duplicates existing connection.
- // Ack again in case the SYN was a resend.
- c = s.conns[connKey{resolvedAddrStr(syn.addr), c.recv_id}]
- if c.send_id != syn.conn_id {
- panic(":|")
- }
- c.sendState()
- return
- }
- c.sendState()
- ok = true
- return
- }
- // Accept and return a new uTP connection.
- func (s *Socket) Accept() (net.Conn, error) {
- mu.Lock()
- defer mu.Unlock()
- for {
- syn, err := s.nextSyn()
- if err != nil {
- return nil, err
- }
- c, ok := s.ackSyn(syn)
- if ok {
- c.updateCanWrite()
- return c, nil
- }
- }
- }
- // The address we're listening on for new uTP connections.
- func (s *Socket) Addr() net.Addr {
- return s.pc.LocalAddr()
- }
- func (s *Socket) CloseNow() error {
- mu.Lock()
- defer mu.Unlock()
- s.closed.Set()
- for _, c := range s.conns {
- c.closeNow()
- }
- s.destroy()
- s.wgReadWrite.Wait()
- return nil
- }
- func (s *Socket) Close() error {
- mu.Lock()
- defer mu.Unlock()
- s.closed.Set()
- s.lazyDestroy()
- return nil
- }
- func (s *Socket) lazyDestroy() {
- if len(s.conns) != 0 {
- return
- }
- if !s.closed.IsSet() {
- return
- }
- s.destroy()
- }
- func (s *Socket) destroy() {
- delete(sockets, s)
- s.destroyed.Set()
- s.pc.Close()
- for _, c := range s.conns {
- c.destroy(errors.New("Socket destroyed"))
- }
- }
- func (s *Socket) LocalAddr() net.Addr {
- return s.pc.LocalAddr()
- }
- func (s *Socket) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
- select {
- case read, ok := <-s.unusedReads:
- if !ok {
- err = io.EOF
- return
- }
- n = copy(p, read.data)
- addr = read.from
- return
- case <-s.connDeadlines.read.passed.LockedChan(&mu):
- err = errTimeout
- return
- }
- }
- func (s *Socket) WriteTo(b []byte, addr net.Addr) (n int, err error) {
- mu.Lock()
- if s.connDeadlines.write.passed.IsSet() {
- err = errTimeout
- }
- s.wgReadWrite.Add(1)
- defer s.wgReadWrite.Done()
- mu.Unlock()
- if err != nil {
- return
- }
- return s.pc.WriteTo(b, addr)
- }
|