conn.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package sftp
  2. import (
  3. "encoding"
  4. "fmt"
  5. "io"
  6. "sync"
  7. )
// conn implements a bidirectional channel on which client and server
// connections are multiplexed.
//
// Packets are read through the embedded io.Reader. Writes (sendPacket) and
// Close go through the embedded io.WriteCloser and are serialised by the
// embedded mutex.
type conn struct {
	io.Reader
	io.WriteCloser
	// this is the same allocator used in packet manager; it is handed to
	// recvPacket. NOTE(review): presumably nil when the allocator is
	// disabled (see recvPacket's comment) — confirm at construction sites.
	alloc      *allocator
	sync.Mutex // used to serialise writes to sendPacket
}
// recvPacket reads the next packet from the connection by delegating to the
// package-level recvPacket with c and c.alloc, and returns the packet type
// and payload (clientConn.recv consumes them as typ and data).
//
// The orderID is used in server mode if the allocator is enabled.
// For the client mode just pass 0.
// It returns io.EOF if the connection is closed and
// there are no more packets to read.
func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) {
	return recvPacket(c, c.alloc, orderID)
}
// sendPacket writes m to the connection by delegating to the package-level
// sendPacket.
//
// The call is made while holding c's mutex, so packets written by concurrent
// goroutines are serialised rather than interleaved on the wire.
func (c *conn) sendPacket(m encoding.BinaryMarshaler) error {
	c.Lock()
	defer c.Unlock()
	return sendPacket(c, m)
}
// Close closes the embedded WriteCloser.
//
// It takes the same mutex as sendPacket, so it waits for an in-progress write
// to finish instead of closing the writer underneath it.
// NOTE(review): a write that blocks indefinitely will therefore also block Close.
func (c *conn) Close() error {
	c.Lock()
	defer c.Unlock()
	return c.WriteCloser.Close()
}
// clientConn is the client side of a conn. It records which channel is
// waiting for the response to each outstanding request ID, so that recv can
// route every incoming response to its waiter.
type clientConn struct {
	conn
	wg sync.WaitGroup // Close waits on this after closing conn

	// This mutex is separate from the one embedded in conn. Because it sits
	// at a shallower embedding depth, c.Lock() on a *clientConn selects this
	// one, while writes on the wire use conn's mutex via conn.sendPacket.
	sync.Mutex                          // protects inflight
	inflight   map[uint32]chan<- result // outstanding requests

	closed chan struct{} // closed by broadcastErr once the connection has shut down
	err    error         // shutdown cause; written by broadcastErr before closed is closed
}
// Wait blocks until the conn has shut down, and return the error
// causing the shutdown. It can be called concurrently from multiple
// goroutines.
//
// Shutdown is signalled by broadcastErr closing c.closed. c.err is written
// before that close, so reading it after the receive below is race-free.
func (c *clientConn) Wait() error {
	<-c.closed
	return c.err
}
// Close closes the SFTP session.
//
// It closes the underlying conn first. The deferred wg.Wait then blocks until
// every goroutine tracked by wg has finished before the close error is
// returned. NOTE(review): wg presumably tracks the recv goroutine, which exits
// once the closed connection makes its read fail — confirm at the Add site.
func (c *clientConn) Close() error {
	defer c.wg.Wait()
	return c.conn.Close()
}
  54. // recv continuously reads from the server and forwards responses to the
  55. // appropriate channel.
  56. func (c *clientConn) recv() error {
  57. defer c.conn.Close()
  58. for {
  59. typ, data, err := c.recvPacket(0)
  60. if err != nil {
  61. return err
  62. }
  63. sid, _, err := unmarshalUint32Safe(data)
  64. if err != nil {
  65. return err
  66. }
  67. ch, ok := c.getChannel(sid)
  68. if !ok {
  69. // This is an unexpected occurrence. Send the error
  70. // back to all listeners so that they terminate
  71. // gracefully.
  72. return fmt.Errorf("sid not found: %d", sid)
  73. }
  74. ch <- result{typ: typ, data: data}
  75. }
  76. }
  77. func (c *clientConn) putChannel(ch chan<- result, sid uint32) bool {
  78. c.Lock()
  79. defer c.Unlock()
  80. select {
  81. case <-c.closed:
  82. // already closed with broadcastErr, return error on chan.
  83. ch <- result{err: ErrSSHFxConnectionLost}
  84. return false
  85. default:
  86. }
  87. c.inflight[sid] = ch
  88. return true
  89. }
  90. func (c *clientConn) getChannel(sid uint32) (chan<- result, bool) {
  91. c.Lock()
  92. defer c.Unlock()
  93. ch, ok := c.inflight[sid]
  94. delete(c.inflight, sid)
  95. return ch, ok
  96. }
// result captures the outcome of waiting for a packet from the server: either
// the received packet, or the error that prevented a response from arriving.
type result struct {
	typ  byte   // packet type as returned by recvPacket
	data []byte // packet payload; recv reads the request ID from its first 4 bytes
	err  error  // set by putChannel, dispatchRequest, or broadcastErr when no response arrived
}
// idmarshaler is a request packet that can be serialised and that exposes
// the request ID dispatchRequest uses to match it with its response.
type idmarshaler interface {
	id() uint32
	encoding.BinaryMarshaler
}
  107. func (c *clientConn) sendPacket(ch chan result, p idmarshaler) (byte, []byte, error) {
  108. if cap(ch) < 1 {
  109. ch = make(chan result, 1)
  110. }
  111. c.dispatchRequest(ch, p)
  112. s := <-ch
  113. return s.typ, s.data, s.err
  114. }
  115. // dispatchRequest should ideally only be called by race-detection tests outside of this file,
  116. // where you have to ensure two packets are in flight sequentially after each other.
  117. func (c *clientConn) dispatchRequest(ch chan<- result, p idmarshaler) {
  118. sid := p.id()
  119. if !c.putChannel(ch, sid) {
  120. // already closed.
  121. return
  122. }
  123. if err := c.conn.sendPacket(p); err != nil {
  124. if ch, ok := c.getChannel(sid); ok {
  125. ch <- result{err: err}
  126. }
  127. }
  128. }
// broadcastErr sends an error to all goroutines waiting for a response.
//
// Each registered waiter receives ErrSSHFxConnectionLost, not err itself.
// err is recorded in c.err for Wait, and c.closed is then closed so later
// putChannel calls refuse new registrations. It must be called at most once,
// because a second call would close c.closed again and panic.
func (c *clientConn) broadcastErr(err error) {
	c.Lock()
	defer c.Unlock()
	bcastRes := result{err: ErrSSHFxConnectionLost}
	for sid, ch := range c.inflight {
		// Sent while holding the lock: this blocks unless ch has buffer space
		// or a ready receiver.
		ch <- bcastRes
		// Replace the chan in inflight,
		// we have hijacked this chan,
		// and this guarantees always-only-once sending.
		// (A later getChannel, e.g. dispatchRequest's write-error path, gets
		// this throwaway buffered chan instead of the original waiter's.)
		c.inflight[sid] = make(chan<- result, 1)
	}
	// Write err before closing closed, so Wait observes it after <-c.closed.
	c.err = err
	close(c.closed)
}
// serverConn is the server side of a conn. It adds helpers, such as
// sendError, on top of the embedded conn's packet I/O.
type serverConn struct {
	conn
}
  147. func (s *serverConn) sendError(id uint32, err error) error {
  148. return s.sendPacket(statusFromError(id, err))
  149. }