| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216 |
- package sftp
- import (
- "encoding"
- "sort"
- "sync"
- )
- // The goal of the packetManager is to keep the outgoing packets in the same
- // order as the incoming as is requires by section 7 of the RFC.
- type packetManager struct {
- requests chan orderedPacket
- responses chan orderedPacket
- fini chan struct{}
- incoming orderedPackets
- outgoing orderedPackets
- sender packetSender // connection object
- working *sync.WaitGroup
- packetCount uint32
- // it is not nil if the allocator is enabled
- alloc *allocator
- }
// packetSender is the minimal view of a connection the packet manager needs:
// something able to transmit a single marshalable packet.
type packetSender interface {
	// sendPacket transmits pkt and reports any failure to do so.
	sendPacket(pkt encoding.BinaryMarshaler) error
}
- func newPktMgr(sender packetSender) *packetManager {
- s := &packetManager{
- requests: make(chan orderedPacket, SftpServerWorkerCount),
- responses: make(chan orderedPacket, SftpServerWorkerCount),
- fini: make(chan struct{}),
- incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
- outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
- sender: sender,
- working: &sync.WaitGroup{},
- }
- go s.controller()
- return s
- }
- // // packet ordering
- func (s *packetManager) newOrderID() uint32 {
- s.packetCount++
- return s.packetCount
- }
- // returns the next orderID without incrementing it.
- // This is used before receiving a new packet, with the allocator enabled, to associate
- // the slice allocated for the received packet with the orderID that will be used to mark
- // the allocated slices for reuse once the request is served
- func (s *packetManager) getNextOrderID() uint32 {
- return s.packetCount + 1
- }
- type orderedRequest struct {
- requestPacket
- orderid uint32
- }
- func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
- return orderedRequest{requestPacket: p, orderid: s.newOrderID()}
- }
- func (p orderedRequest) orderID() uint32 { return p.orderid }
- func (p orderedRequest) setOrderID(oid uint32) { p.orderid = oid }
- type orderedResponse struct {
- responsePacket
- orderid uint32
- }
- func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
- ) orderedResponse {
- return orderedResponse{responsePacket: p, orderid: id}
- }
- func (p orderedResponse) orderID() uint32 { return p.orderid }
- func (p orderedResponse) setOrderID(oid uint32) { p.orderid = oid }
// orderedPacket is any packet that exposes both its request id and the
// order ID used by the packet manager for sequencing.
type orderedPacket interface {
	id() uint32
	orderID() uint32
}

// orderedPackets is a queue of packets that can be ordered by order ID.
type orderedPackets []orderedPacket

// Sort reorders the queue in place so that order IDs are ascending.
func (o orderedPackets) Sort() {
	byOrderID := func(a, b int) bool {
		return o[a].orderID() < o[b].orderID()
	}
	sort.Slice(o, byOrderID)
}
- // // packet registry
- // register incoming packets to be handled
- func (s *packetManager) incomingPacket(pkt orderedRequest) {
- s.working.Add(1)
- s.requests <- pkt
- }
- // register outgoing packets as being ready
- func (s *packetManager) readyPacket(pkt orderedResponse) {
- s.responses <- pkt
- s.working.Done()
- }
- // shut down packetManager controller
- func (s *packetManager) close() {
- // pause until current packets are processed
- s.working.Wait()
- close(s.fini)
- }
- // Passed a worker function, returns a channel for incoming packets.
- // Keep process packet responses in the order they are received while
- // maximizing throughput of file transfers.
- func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
- ) chan orderedRequest {
- // multiple workers for faster read/writes
- rwChan := make(chan orderedRequest, SftpServerWorkerCount)
- for i := 0; i < SftpServerWorkerCount; i++ {
- runWorker(rwChan)
- }
- // single worker to enforce sequential processing of everything else
- cmdChan := make(chan orderedRequest)
- runWorker(cmdChan)
- pktChan := make(chan orderedRequest, SftpServerWorkerCount)
- go func() {
- for pkt := range pktChan {
- switch pkt.requestPacket.(type) {
- case *sshFxpReadPacket, *sshFxpWritePacket:
- s.incomingPacket(pkt)
- rwChan <- pkt
- continue
- case *sshFxpClosePacket:
- // wait for reads/writes to finish when file is closed
- // incomingPacket() call must occur after this
- s.working.Wait()
- }
- s.incomingPacket(pkt)
- // all non-RW use sequential cmdChan
- cmdChan <- pkt
- }
- close(rwChan)
- close(cmdChan)
- s.close()
- }()
- return pktChan
- }
- // process packets
- func (s *packetManager) controller() {
- for {
- select {
- case pkt := <-s.requests:
- debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderID())
- s.incoming = append(s.incoming, pkt)
- s.incoming.Sort()
- case pkt := <-s.responses:
- debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderID())
- s.outgoing = append(s.outgoing, pkt)
- s.outgoing.Sort()
- case <-s.fini:
- return
- }
- s.maybeSendPackets()
- }
- }
- // send as many packets as are ready
- func (s *packetManager) maybeSendPackets() {
- for {
- if len(s.outgoing) == 0 || len(s.incoming) == 0 {
- debug("break! -- outgoing: %v; incoming: %v",
- len(s.outgoing), len(s.incoming))
- break
- }
- out := s.outgoing[0]
- in := s.incoming[0]
- // debug("incoming: %v", ids(s.incoming))
- // debug("outgoing: %v", ids(s.outgoing))
- if in.orderID() == out.orderID() {
- debug("Sending packet: %v", out.id())
- s.sender.sendPacket(out.(encoding.BinaryMarshaler))
- if s.alloc != nil {
- // mark for reuse the slices allocated for this request
- s.alloc.ReleasePages(in.orderID())
- }
- // pop off heads
- copy(s.incoming, s.incoming[1:]) // shift left
- s.incoming[len(s.incoming)-1] = nil // clear last
- s.incoming = s.incoming[:len(s.incoming)-1] // remove last
- copy(s.outgoing, s.outgoing[1:]) // shift left
- s.outgoing[len(s.outgoing)-1] = nil // clear last
- s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
- } else {
- break
- }
- }
- }
- // func oids(o []orderedPacket) []uint32 {
- // res := make([]uint32, 0, len(o))
- // for _, v := range o {
- // res = append(res, v.orderId())
- // }
- // return res
- // }
- // func ids(o []orderedPacket) []uint32 {
- // res := make([]uint32, 0, len(o))
- // for _, v := range o {
- // res = append(res, v.id())
- // }
- // return res
- // }
|