forwarder.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ssh
  15. import (
  16. "context"
  17. "fmt"
  18. "io"
  19. "net"
  20. "time"
  21. "yunion.io/x/log"
  22. )
  // TickFunc is the signature of the periodic callback carried by the
// TickCb/tickCb fields. The forwarder invokes it with the context it is
// running under.
type TickFunc func(ctx context.Context)
  // LocalForwardReq holds the parameters of a local port forward.
//
// NOTE(review): the code that consumes this struct is not in this file. The
// field comments below are inferred by analogy with RemoteForwardReq and
// should be confirmed against the caller.
type LocalForwardReq struct {
	// LocalAddr is presumably the local address to listen on.
	LocalAddr string
	// LocalPort is presumably the local port to listen on.
	LocalPort int
	// RemoteAddr is presumably the address, reached via the remote side,
	// that accepted connections are forwarded to.
	RemoteAddr string
	// RemotePort is presumably the port paired with RemoteAddr.
	RemotePort int
	// Tick is presumably the interval between TickCb calls (TODO confirm
	// it is passed through to the forwarder's tick field).
	Tick time.Duration
	// TickCb is presumably the periodic callback run while the forward
	// is active (TODO confirm).
	TickCb TickFunc
}
  // RemoteForwardReq holds the parameters of a remote port forward: the
// remote side listens on RemoteAddr:RemotePort and connections are
// forwarded to LocalAddr:LocalPort.
type RemoteForwardReq struct {
	// LocalAddr is the address the forward will forward to
	LocalAddr string
	// LocalPort is the port the forward will forward to
	LocalPort int
	// RemoteAddr is the address on the remote to listen on
	RemoteAddr string
	// RemotePort is the port on the remote to listen on
	RemotePort int
	// Tick is presumably the interval between TickCb calls (TODO confirm
	// it is passed through to the forwarder's tick field).
	Tick time.Duration
	// TickCb is presumably the periodic callback run while the forward
	// is active (TODO confirm).
	TickCb TickFunc
}
  // Function types used to plug connection setup and teardown notification
// into a forwarder.
type (
	// dialFunc opens the outgoing leg of a forwarded connection, given a
	// network name and a host:port address.
	dialFunc func(n, addr string) (net.Conn, error)

	// doneFunc is notified with an address and port once a forwarder stops.
	doneFunc func(laddr string, lport int)
)
  // forwarder accepts connections on listener and relays each one to a
// connection opened with dial. See Start and Stop.
type forwarder struct {
	// listener supplies incoming connections. Stop closes it, and Start
	// closes it when it returns.
	listener net.Listener
	// dial opens the outgoing leg for each accepted connection. Start calls
	// it with "tcp" and net.JoinHostPort(dialAddr, dialPort).
	dial     dialFunc
	dialAddr string
	dialPort int
	// done, if non-nil, is called with doneAddr and donePort when Start
	// returns.
	done     doneFunc
	doneAddr string
	donePort int
	// tickCb is called every tick while Start runs, but only when tick > 0
	// and tickCb != nil.
	tick   time.Duration
	tickCb TickFunc
}
  57. func (fwd *forwarder) Stop(ctx context.Context) {
  58. fwd.listener.Close()
  59. }
  60. func (fwd *forwarder) Start(
  61. ctx context.Context,
  62. ) {
  63. var (
  64. listener = fwd.listener
  65. dial = fwd.dial
  66. dialAddr = fwd.dialAddr
  67. dialPort = fwd.dialPort
  68. done = fwd.done
  69. doneAddr = fwd.doneAddr
  70. donePort = fwd.donePort
  71. tick = fwd.tick
  72. tickCb = fwd.tickCb
  73. )
  74. ctx, cancelFunc := context.WithCancel(ctx)
  75. if done != nil {
  76. defer done(doneAddr, donePort)
  77. }
  78. defer listener.Close()
  79. go func() { // accept local/remote connection
  80. for {
  81. conn, err := listener.Accept()
  82. if err != nil {
  83. log.Warningf("local forward: accept: %v", err)
  84. cancelFunc()
  85. break
  86. }
  87. go func(local net.Conn) {
  88. defer local.Close()
  89. // dial remote/local
  90. addr := net.JoinHostPort(dialAddr, fmt.Sprintf("%d", dialPort))
  91. remote, err := dial("tcp", addr)
  92. if err != nil {
  93. log.Warningf("local forward: dial remote: %v", err)
  94. return
  95. }
  96. defer remote.Close()
  97. // forward
  98. go io.Copy(local, remote)
  99. go io.Copy(remote, local)
  100. <-ctx.Done()
  101. }(conn)
  102. }
  103. }()
  104. if tick > 0 && tickCb != nil {
  105. go func() {
  106. ticker := time.NewTicker(tick)
  107. defer ticker.Stop()
  108. for {
  109. select {
  110. case <-ticker.C:
  111. tickCb(ctx)
  112. case <-ctx.Done():
  113. return
  114. }
  115. }
  116. }()
  117. }
  118. for {
  119. select {
  120. case <-ctx.Done():
  121. return
  122. }
  123. }
  124. }