inproc.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. package inproc
import (
	"errors"
	"io"
	"math"
	"net"
	"os"
	"strconv"
	"sync"
	"time"

	"github.com/anacrolix/missinggo"
)
  12. var (
  13. mu sync.Mutex
  14. cond = sync.Cond{L: &mu}
  15. nextPort int = 1
  16. conns = map[int]*packetConn{}
  17. )
  18. type Addr struct {
  19. Port int
  20. }
  21. func (Addr) Network() string {
  22. return "inproc"
  23. }
  24. func (me Addr) String() string {
  25. return ":" + strconv.FormatInt(int64(me.Port), 10)
  26. }
  27. func getPort() (port int) {
  28. mu.Lock()
  29. defer mu.Unlock()
  30. port = nextPort
  31. nextPort++
  32. return
  33. }
  34. func ResolveAddr(network, str string) (net.Addr, error) {
  35. return ResolveInprocAddr(network, str)
  36. }
  37. func ResolveInprocAddr(network, str string) (addr Addr, err error) {
  38. if str == "" {
  39. addr.Port = getPort()
  40. return
  41. }
  42. _, p, err := net.SplitHostPort(str)
  43. if err != nil {
  44. return
  45. }
  46. i64, err := strconv.ParseInt(p, 10, 0)
  47. if err != nil {
  48. return
  49. }
  50. addr.Port = int(i64)
  51. if addr.Port == 0 {
  52. addr.Port = getPort()
  53. }
  54. return
  55. }
  56. func ListenPacket(network, addrStr string) (nc net.PacketConn, err error) {
  57. addr, err := ResolveInprocAddr(network, addrStr)
  58. if err != nil {
  59. return
  60. }
  61. mu.Lock()
  62. defer mu.Unlock()
  63. if _, ok := conns[addr.Port]; ok {
  64. err = errors.New("address in use")
  65. return
  66. }
  67. pc := &packetConn{
  68. addr: addr,
  69. readDeadline: newCondDeadline(&cond),
  70. writeDeadline: newCondDeadline(&cond),
  71. }
  72. conns[addr.Port] = pc
  73. nc = pc
  74. return
  75. }
  76. type packet struct {
  77. data []byte
  78. addr Addr
  79. }
  80. type packetConn struct {
  81. closed bool
  82. addr Addr
  83. reads []packet
  84. readDeadline *condDeadline
  85. writeDeadline *condDeadline
  86. }
  87. func (me *packetConn) Close() error {
  88. mu.Lock()
  89. defer mu.Unlock()
  90. me.closed = true
  91. delete(conns, me.addr.Port)
  92. cond.Broadcast()
  93. return nil
  94. }
  95. func (me *packetConn) LocalAddr() net.Addr {
  96. return me.addr
  97. }
  98. type errTimeout struct{}
  99. func (errTimeout) Error() string {
  100. return "i/o timeout"
  101. }
  102. func (errTimeout) Temporary() bool {
  103. return false
  104. }
  105. func (errTimeout) Timeout() bool {
  106. return true
  107. }
  108. var _ net.Error = errTimeout{}
  109. func (me *packetConn) WriteTo(b []byte, na net.Addr) (n int, err error) {
  110. mu.Lock()
  111. defer mu.Unlock()
  112. if me.closed {
  113. err = errors.New("closed")
  114. return
  115. }
  116. if me.writeDeadline.exceeded() {
  117. err = errTimeout{}
  118. return
  119. }
  120. n = len(b)
  121. port := missinggo.AddrPort(na)
  122. c, ok := conns[port]
  123. if !ok {
  124. // log.Printf("no conn for port %d", port)
  125. return
  126. }
  127. c.reads = append(c.reads, packet{append([]byte(nil), b...), me.addr})
  128. cond.Broadcast()
  129. return
  130. }
  131. func (me *packetConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
  132. mu.Lock()
  133. defer mu.Unlock()
  134. for {
  135. if len(me.reads) != 0 {
  136. r := me.reads[0]
  137. me.reads = me.reads[1:]
  138. n = copy(b, r.data)
  139. addr = r.addr
  140. // log.Println(addr)
  141. return
  142. }
  143. if me.closed {
  144. err = io.EOF
  145. return
  146. }
  147. if me.readDeadline.exceeded() {
  148. err = errTimeout{}
  149. return
  150. }
  151. cond.Wait()
  152. }
  153. }
  154. func (me *packetConn) SetDeadline(t time.Time) error {
  155. me.writeDeadline.setDeadline(t)
  156. me.readDeadline.setDeadline(t)
  157. return nil
  158. }
  159. func (me *packetConn) SetReadDeadline(t time.Time) error {
  160. me.readDeadline.setDeadline(t)
  161. return nil
  162. }
  163. func (me *packetConn) SetWriteDeadline(t time.Time) error {
  164. me.writeDeadline.setDeadline(t)
  165. return nil
  166. }
  167. func newCondDeadline(cond *sync.Cond) (ret *condDeadline) {
  168. ret = &condDeadline{
  169. timer: time.AfterFunc(math.MaxInt64, func() {
  170. mu.Lock()
  171. ret._exceeded = true
  172. mu.Unlock()
  173. cond.Broadcast()
  174. }),
  175. }
  176. ret.setDeadline(time.Time{})
  177. return
  178. }
  179. type condDeadline struct {
  180. mu sync.Mutex
  181. _exceeded bool
  182. timer *time.Timer
  183. }
  184. func (me *condDeadline) setDeadline(t time.Time) {
  185. me.mu.Lock()
  186. defer me.mu.Unlock()
  187. me._exceeded = false
  188. if t.IsZero() {
  189. me.timer.Stop()
  190. return
  191. }
  192. me.timer.Reset(t.Sub(time.Now()))
  193. }
  194. func (me *condDeadline) exceeded() bool {
  195. me.mu.Lock()
  196. defer me.mu.Unlock()
  197. return me._exceeded
  198. }