packet-manager.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. package sftp
  2. import (
  3. "encoding"
  4. "sort"
  5. "sync"
  6. )
  7. // The goal of the packetManager is to keep the outgoing packets in the same
  8. // order as the incoming as is requires by section 7 of the RFC.
  9. type packetManager struct {
  10. requests chan orderedPacket
  11. responses chan orderedPacket
  12. fini chan struct{}
  13. incoming orderedPackets
  14. outgoing orderedPackets
  15. sender packetSender // connection object
  16. working *sync.WaitGroup
  17. packetCount uint32
  18. // it is not nil if the allocator is enabled
  19. alloc *allocator
  20. }
  // packetSender is the outbound side of the connection: packetManager hands
  // each response to sendPacket once that response is next in request order.
  type packetSender interface {
  	sendPacket(pkt encoding.BinaryMarshaler) error
  }
  24. func newPktMgr(sender packetSender) *packetManager {
  25. s := &packetManager{
  26. requests: make(chan orderedPacket, SftpServerWorkerCount),
  27. responses: make(chan orderedPacket, SftpServerWorkerCount),
  28. fini: make(chan struct{}),
  29. incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
  30. outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
  31. sender: sender,
  32. working: &sync.WaitGroup{},
  33. }
  34. go s.controller()
  35. return s
  36. }
  37. // // packet ordering
  38. func (s *packetManager) newOrderID() uint32 {
  39. s.packetCount++
  40. return s.packetCount
  41. }
  42. // returns the next orderID without incrementing it.
  43. // This is used before receiving a new packet, with the allocator enabled, to associate
  44. // the slice allocated for the received packet with the orderID that will be used to mark
  45. // the allocated slices for reuse once the request is served
  46. func (s *packetManager) getNextOrderID() uint32 {
  47. return s.packetCount + 1
  48. }
  49. type orderedRequest struct {
  50. requestPacket
  51. orderid uint32
  52. }
  53. func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
  54. return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
  55. }
  56. func (p orderedRequest) orderID() uint32 { return p.orderid }
  57. func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
  58. type orderedResponse struct {
  59. responsePacket
  60. orderid uint32
  61. }
  62. func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
  63. ) orderedResponse {
  64. return orderedResponse{responsePacket: p, orderid: id}
  65. }
  66. func (p orderedResponse) orderID() uint32 { return p.orderid }
  67. func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
  // orderedPacket is anything queued by packetManager: a packet exposing an
  // id() (presumably the SFTP request id) and the orderID recording its
  // position in arrival order.
  type orderedPacket interface {
  	orderID() uint32
  	id() uint32
  }

  // orderedPackets is a queue of packets that packetManager keeps sorted by
  // orderID.
  type orderedPackets []orderedPacket

  // Sort orders o in place by ascending orderID. It uses sort.Slice, so the
  // relative order of entries with equal orderIDs is not guaranteed.
  func (o orderedPackets) Sort() {
  	byOrderID := func(a, b int) bool {
  		return o[a].orderID() < o[b].orderID()
  	}
  	sort.Slice(o, byOrderID)
  }
  78. // // packet registry
  79. // register incoming packets to be handled
  80. func (s *packetManager) incomingPacket(pkt orderedRequest) {
  81. s.working.Add(1)
  82. s.requests <- pkt
  83. }
  84. // register outgoing packets as being ready
  85. func (s *packetManager) readyPacket(pkt orderedResponse) {
  86. s.responses <- pkt
  87. s.working.Done()
  88. }
  89. // shut down packetManager controller
  90. func (s *packetManager) close() {
  91. // pause until current packets are processed
  92. s.working.Wait()
  93. close(s.fini)
  94. }
  95. // Passed a worker function, returns a channel for incoming packets.
  96. // Keep process packet responses in the order they are received while
  97. // maximizing throughput of file transfers.
  98. func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
  99. ) chan orderedRequest {
  100. // multiple workers for faster read/writes
  101. rwChan := make(chan orderedRequest, SftpServerWorkerCount)
  102. for i := 0; i < SftpServerWorkerCount; i++ {
  103. runWorker(rwChan)
  104. }
  105. // single worker to enforce sequential processing of everything else
  106. cmdChan := make(chan orderedRequest)
  107. runWorker(cmdChan)
  108. pktChan := make(chan orderedRequest, SftpServerWorkerCount)
  109. go func() {
  110. for pkt := range pktChan {
  111. switch pkt.requestPacket.(type) {
  112. case *sshFxpReadPacket, *sshFxpWritePacket:
  113. s.incomingPacket(pkt)
  114. rwChan <- pkt
  115. continue
  116. case *sshFxpClosePacket:
  117. // wait for reads/writes to finish when file is closed
  118. // incomingPacket() call must occur after this
  119. s.working.Wait()
  120. }
  121. s.incomingPacket(pkt)
  122. // all non-RW use sequential cmdChan
  123. cmdChan <- pkt
  124. }
  125. close(rwChan)
  126. close(cmdChan)
  127. s.close()
  128. }()
  129. return pktChan
  130. }
  131. // process packets
  132. func (s *packetManager) controller() {
  133. for {
  134. select {
  135. case pkt := <-s.requests:
  136. debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
  137. s.incoming = append(s.incoming, pkt)
  138. s.incoming.Sort()
  139. case pkt := <-s.responses:
  140. debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
  141. s.outgoing = append(s.outgoing, pkt)
  142. s.outgoing.Sort()
  143. case <-s.fini:
  144. return
  145. }
  146. s.maybeSendPackets()
  147. }
  148. }
  149. // send as many packets as are ready
  150. func (s *packetManager) maybeSendPackets() {
  151. for {
  152. if len(s.outgoing) == 0 || len(s.incoming) == 0 {
  153. debug("break! -- outgoing: %v; incoming: %v",
  154. len(s.outgoing), len(s.incoming))
  155. break
  156. }
  157. out := s.outgoing[0]
  158. in := s.incoming[0]
  159. // debug("incoming: %v", ids(s.incoming))
  160. // debug("outgoing: %v", ids(s.outgoing))
  161. if in.orderID() == out.orderID() {
  162. debug("Sending packet: %v", out.id())
  163. s.sender.sendPacket(out.(encoding.BinaryMarshaler))
  164. if s.alloc != nil {
  165. // mark for reuse the slices allocated for this request
  166. s.alloc.ReleasePages(in.orderID())
  167. }
  168. // pop off heads
  169. copy(s.incoming, s.incoming[1:]) // shift left
  170. s.incoming[len(s.incoming)-1] = nil // clear last
  171. s.incoming = s.incoming[:len(s.incoming)-1] // remove last
  172. copy(s.outgoing, s.outgoing[1:]) // shift left
  173. s.outgoing[len(s.outgoing)-1] = nil // clear last
  174. s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
  175. } else {
  176. break
  177. }
  178. }
  179. }
  180. // func oids(o []orderedPacket) []uint32 {
  181. // res := make([]uint32, 0, len(o))
  182. // for _, v := range o {
  183. // res = append(res, v.orderID())
  184. // }
  185. // return res
  186. // }
  187. // func ids(o []orderedPacket) []uint32 {
  188. // res := make([]uint32, 0, len(o))
  189. // for _, v := range o {
  190. // res = append(res, v.id())
  191. // }
  192. // return res
  193. // }