connctx.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. // Package connctx wraps net.Conn using context.Context.
  2. package connctx
  3. import (
  4. "context"
  5. "errors"
  6. "io"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. )
  // ErrClosing is the error WriteContext reports once the ConnCtx has been
// closed. (ReadContext reports io.EOF in the same situation.)
var (
	ErrClosing = errors.New("use of closed network connection")
)
  // Reader is implemented by types that can read into a buffer under the
// control of a context.Context.
type Reader interface {
	// ReadContext reads into p; ctx is meant to bound how long the read may
	// block.
	ReadContext(ctx context.Context, p []byte) (n int, err error)
}
  // Writer is implemented by types that can write a buffer under the control
// of a context.Context.
type Writer interface {
	// WriteContext writes p; ctx is meant to bound how long the write may
	// block.
	WriteContext(ctx context.Context, p []byte) (n int, err error)
}
  22. // ReadWriter is a composite of ReadWriter.
  23. type ReadWriter interface {
  24. Reader
  25. Writer
  26. }
  27. // ConnCtx is a wrapper of net.Conn using context.Context.
  28. type ConnCtx interface {
  29. Reader
  30. Writer
  31. io.Closer
  32. LocalAddr() net.Addr
  33. RemoteAddr() net.Addr
  34. Conn() net.Conn
  35. }
  // connCtx is the ConnCtx implementation returned by New.
type connCtx struct {
	// readMu serializes ReadContext calls; writeMu serializes WriteContext
	// calls. Close acquires both (writeMu first) before closing `closed`.
	readMu  sync.Mutex
	writeMu sync.Mutex

	// closed is closed exactly once, via closeOnce, when Close runs.
	closeOnce sync.Once
	closed    chan struct{}

	// nextConn is the wrapped connection all I/O is delegated to.
	nextConn net.Conn
}
  // veryOld is one nanosecond after the Unix epoch. It is installed as a
// read/write deadline to force a blocked Read/Write on the wrapped conn to
// return once the caller's context is done. It is deliberately not the zero
// time.Time, which would mean "no deadline".
var veryOld = time.Unix(0, 0).Add(time.Nanosecond) //nolint:gochecknoglobals
  44. // New creates a new ConnCtx wrapping given net.Conn.
  45. func New(conn net.Conn) ConnCtx {
  46. c := &connCtx{
  47. nextConn: conn,
  48. closed: make(chan struct{}),
  49. }
  50. return c
  51. }
  52. func (c *connCtx) ReadContext(ctx context.Context, b []byte) (int, error) {
  53. c.readMu.Lock()
  54. defer c.readMu.Unlock()
  55. select {
  56. case <-c.closed:
  57. return 0, io.EOF
  58. default:
  59. }
  60. done := make(chan struct{})
  61. var wg sync.WaitGroup
  62. var errSetDeadline atomic.Value
  63. wg.Add(1)
  64. go func() {
  65. defer wg.Done()
  66. select {
  67. case <-ctx.Done():
  68. // context canceled
  69. if err := c.nextConn.SetReadDeadline(veryOld); err != nil {
  70. errSetDeadline.Store(err)
  71. return
  72. }
  73. <-done
  74. if err := c.nextConn.SetReadDeadline(time.Time{}); err != nil {
  75. errSetDeadline.Store(err)
  76. }
  77. case <-done:
  78. }
  79. }()
  80. n, err := c.nextConn.Read(b)
  81. close(done)
  82. wg.Wait()
  83. if e := ctx.Err(); e != nil && n == 0 {
  84. err = e
  85. }
  86. if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
  87. err = err2
  88. }
  89. return n, err
  90. }
  91. func (c *connCtx) WriteContext(ctx context.Context, b []byte) (int, error) {
  92. c.writeMu.Lock()
  93. defer c.writeMu.Unlock()
  94. select {
  95. case <-c.closed:
  96. return 0, ErrClosing
  97. default:
  98. }
  99. done := make(chan struct{})
  100. var wg sync.WaitGroup
  101. var errSetDeadline atomic.Value
  102. wg.Add(1)
  103. go func() {
  104. defer wg.Done()
  105. select {
  106. case <-ctx.Done():
  107. // context canceled
  108. if err := c.nextConn.SetWriteDeadline(veryOld); err != nil {
  109. errSetDeadline.Store(err)
  110. return
  111. }
  112. <-done
  113. if err := c.nextConn.SetWriteDeadline(time.Time{}); err != nil {
  114. errSetDeadline.Store(err)
  115. }
  116. case <-done:
  117. }
  118. }()
  119. n, err := c.nextConn.Write(b)
  120. close(done)
  121. wg.Wait()
  122. if e := ctx.Err(); e != nil && n == 0 {
  123. err = e
  124. }
  125. if err2, ok := errSetDeadline.Load().(error); ok && err == nil && err2 != nil {
  126. err = err2
  127. }
  128. return n, err
  129. }
  130. func (c *connCtx) Close() error {
  131. err := c.nextConn.Close()
  132. c.closeOnce.Do(func() {
  133. c.writeMu.Lock()
  134. c.readMu.Lock()
  135. close(c.closed)
  136. c.readMu.Unlock()
  137. c.writeMu.Unlock()
  138. })
  139. return err
  140. }
  141. func (c *connCtx) LocalAddr() net.Addr {
  142. return c.nextConn.LocalAddr()
  143. }
  144. func (c *connCtx) RemoteAddr() net.Addr {
  145. return c.nextConn.RemoteAddr()
  146. }
  147. func (c *connCtx) Conn() net.Conn {
  148. return c.nextConn
  149. }