client.go 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ssh
  15. import (
  16. "context"
  17. "fmt"
  18. "net"
  19. "strconv"
  20. "sync"
  21. "time"
  22. "golang.org/x/crypto/ssh"
  23. "yunion.io/x/log"
  24. "yunion.io/x/pkg/errors"
  25. "yunion.io/x/pkg/util/sets"
  26. ssh_util "yunion.io/x/onecloud/pkg/util/ssh"
  27. )
  28. type addrMap map[string]interface{}
  29. type portMap map[int]addrMap
  30. func (pm portMap) contains(port int, addr string) bool {
  31. am, ok := pm[port]
  32. if !ok {
  33. return false
  34. }
  35. return am.contains(addr)
  36. }
  37. func (pm portMap) get(port int, addr string) interface{} {
  38. am, ok := pm[port]
  39. if ok {
  40. return am.get(addr)
  41. }
  42. return nil
  43. }
  44. func (pm portMap) set(port int, addr string, v interface{}) {
  45. am, ok := pm[port]
  46. if !ok {
  47. am = addrMap{}
  48. pm[port] = am
  49. }
  50. am.set(addr, v)
  51. }
  52. func (pm portMap) delete(port int, addr string) {
  53. if am, ok := pm[port]; ok {
  54. am.delete(addr)
  55. }
  56. }
  57. func (am addrMap) contains(addr string) bool {
  58. const (
  59. ip4wild = "0.0.0.0"
  60. ip6wild = "::"
  61. )
  62. _, ok := am[addr]
  63. if ok {
  64. return true
  65. }
  66. if _, ok := am[ip4wild]; ok {
  67. return true
  68. }
  69. if _, ok := am[ip6wild]; ok {
  70. return true
  71. }
  72. return false
  73. }
  74. func (am addrMap) get(addr string) interface{} {
  75. return am[addr]
  76. }
  77. func (am addrMap) set(addr string, v interface{}) {
  78. am[addr] = v
  79. }
  80. func (am addrMap) delete(addr string) {
  81. delete(am, addr)
  82. }
  83. type Client struct {
  84. cc *ssh_util.ClientConfig
  85. stopc chan sets.Empty
  86. stopcEx *sync.Mutex
  87. stopcc bool
  88. lfc chan LocalForwardReq
  89. rfc chan RemoteForwardReq
  90. lfclosec chan LocalForwardReq
  91. rfclosec chan RemoteForwardReq
  92. localForwards portMap
  93. remoteForwards portMap
  94. }
  95. func NewClient(cc *ssh_util.ClientConfig) *Client {
  96. c := &Client{
  97. cc: cc,
  98. stopc: make(chan sets.Empty),
  99. stopcEx: &sync.Mutex{},
  100. lfc: make(chan LocalForwardReq),
  101. rfc: make(chan RemoteForwardReq),
  102. lfclosec: make(chan LocalForwardReq),
  103. rfclosec: make(chan RemoteForwardReq),
  104. localForwards: portMap{},
  105. remoteForwards: portMap{},
  106. }
  107. return c
  108. }
  109. func (c *Client) Stop(ctx context.Context) {
  110. c.stopcEx.Lock()
  111. defer c.stopcEx.Unlock()
  112. if !c.stopcc {
  113. close(c.stopc)
  114. c.stopcc = true
  115. }
  116. }
  117. func (c *Client) Start(ctx context.Context) {
  118. ctx, cancelFunc := context.WithCancel(ctx)
  119. defer cancelFunc()
  120. pingT := time.NewTimer(17 * time.Second)
  121. pingFailCount := 0
  122. const pingMaxFail = 3
  123. sshClientC := make(chan *ssh.Client)
  124. var sshClient *ssh.Client
  125. go c.runClientState(ctx, sshClientC)
  126. for {
  127. select {
  128. case sshc := <-sshClientC:
  129. conn := sshc.Conn
  130. localAddr := conn.LocalAddr()
  131. localAddrStr := localAddr.String()
  132. addr, portStr, err := net.SplitHostPort(localAddrStr)
  133. if err != nil {
  134. log.Errorf("split host port of ssh client local addr: %v", err)
  135. sshc.Close()
  136. break
  137. }
  138. port, err := strconv.ParseUint(portStr, 10, 16)
  139. if err != nil {
  140. log.Errorf("parse ssh client local port: %v", err)
  141. sshc.Close()
  142. break
  143. }
  144. if v := c.localForwards.get(int(port), addr); v != nil {
  145. log.Errorf("ssh client local port %d collides with local forward: %#v", port, v)
  146. sshc.Close()
  147. break
  148. }
  149. sshClient = sshc
  150. case req := <-c.lfc:
  151. if sshClient != nil {
  152. c.localForward(ctx, sshClient, req)
  153. }
  154. case req := <-c.rfc:
  155. if sshClient != nil {
  156. c.remoteForward(ctx, sshClient, req)
  157. }
  158. case req := <-c.lfclosec:
  159. c.localForwardClose(ctx, req)
  160. case req := <-c.rfclosec:
  161. c.remoteForwardClose(ctx, req)
  162. case <-pingT.C:
  163. //TODO ping check
  164. //ping fail
  165. if pingFailCount > pingMaxFail {
  166. }
  167. case <-c.stopc:
  168. return
  169. case <-ctx.Done():
  170. return
  171. }
  172. }
  173. }
  174. func (c *Client) runClientState(ctx context.Context, sshClientC chan<- *ssh.Client) {
  175. for {
  176. select {
  177. case <-ctx.Done():
  178. return
  179. default:
  180. }
  181. cc := c.cc
  182. tmoCtx, _ := context.WithTimeout(ctx, 31*time.Second)
  183. sshc, err := cc.ConnectContext(tmoCtx)
  184. if err != nil {
  185. log.Errorf("ssh connect: %s@%s, port %d: %v", cc.Username, cc.Host, cc.Port, err)
  186. waitTmo := time.NewTimer(13 * time.Second)
  187. select {
  188. case <-ctx.Done():
  189. return
  190. case <-waitTmo.C:
  191. }
  192. continue
  193. }
  194. func() {
  195. defer sshc.Conn.Close()
  196. closeC := make(chan struct{})
  197. go func() {
  198. defer close(closeC)
  199. err := sshc.Conn.Wait()
  200. if err != nil {
  201. log.Infof("ssh client conn: %v", err)
  202. }
  203. }()
  204. select {
  205. case sshClientC <- sshc:
  206. case <-closeC:
  207. case <-ctx.Done():
  208. return
  209. }
  210. select {
  211. case <-closeC:
  212. case <-ctx.Done():
  213. return
  214. }
  215. }()
  216. }
  217. }
  218. func (c *Client) connect(ctx context.Context) (*ssh.Client, error) {
  219. sshc, err := c.cc.ConnectContext(ctx)
  220. return sshc, err
  221. }
  222. func (c *Client) LocalForward(ctx context.Context, req LocalForwardReq) {
  223. select {
  224. case c.lfc <- req:
  225. case <-ctx.Done():
  226. }
  227. }
  228. func (c *Client) localForward(ctx context.Context, sshc *ssh.Client, req LocalForwardReq) {
  229. if err := c.localForward_(ctx, sshc, req); err != nil {
  230. log.Errorf("local forward: %v", err)
  231. }
  232. }
  233. func (c *Client) localForward_(ctx context.Context, sshc *ssh.Client, req LocalForwardReq) error {
  234. // check LocalAddr/LocalPort existence
  235. if c.localForwards.contains(req.LocalPort, req.LocalAddr) {
  236. return errors.Errorf("local addr occupied: %s:%d", req.LocalAddr, req.LocalPort)
  237. }
  238. addr := net.JoinHostPort(req.LocalAddr, fmt.Sprintf("%d", req.LocalPort))
  239. listener, err := net.Listen("tcp", addr)
  240. if err != nil {
  241. return errors.Wrapf(err, "tcp listen %s", addr)
  242. }
  243. fwd := &forwarder{
  244. listener: listener,
  245. dial: sshc.Dial,
  246. dialAddr: req.RemoteAddr,
  247. dialPort: req.RemotePort,
  248. done: c.localForwardDone,
  249. doneAddr: req.LocalAddr,
  250. donePort: req.LocalPort,
  251. tick: req.Tick,
  252. tickCb: req.TickCb,
  253. }
  254. c.localForwards.set(req.LocalPort, req.LocalAddr, fwd)
  255. go fwd.Start(ctx)
  256. return nil
  257. }
  258. func (c *Client) localForwardDone(laddr string, lport int) {
  259. c.localForwards.delete(lport, laddr)
  260. }
  261. func (c *Client) RemoteForward(ctx context.Context, req RemoteForwardReq) {
  262. select {
  263. case c.rfc <- req:
  264. case <-ctx.Done():
  265. }
  266. }
  267. func (c *Client) remoteForward(ctx context.Context, sshc *ssh.Client, req RemoteForwardReq) {
  268. if err := c.remoteForward_(ctx, sshc, req); err != nil {
  269. log.Errorf("remote forward: %v", err)
  270. }
  271. }
  272. func (c *Client) remoteForward_(ctx context.Context, sshc *ssh.Client, req RemoteForwardReq) error {
  273. // check RemoteAddr/RemotePort existence
  274. if c.remoteForwards.contains(req.RemotePort, req.RemoteAddr) {
  275. return errors.Errorf("remote addr occupied: %s:%d", req.RemoteAddr, req.RemotePort)
  276. }
  277. addr := net.JoinHostPort(req.RemoteAddr, fmt.Sprintf("%d", req.RemotePort))
  278. listener, err := sshc.Listen("tcp", addr)
  279. if err != nil {
  280. return errors.Wrapf(err, "ssh listen %s", addr)
  281. }
  282. fwd := &forwarder{
  283. listener: listener,
  284. dial: net.Dial,
  285. dialAddr: req.LocalAddr,
  286. dialPort: req.LocalPort,
  287. done: c.remoteForwardDone,
  288. doneAddr: req.RemoteAddr,
  289. donePort: req.RemotePort,
  290. tick: req.Tick,
  291. tickCb: req.TickCb,
  292. }
  293. c.remoteForwards.set(req.RemotePort, req.RemoteAddr, fwd)
  294. go fwd.Start(ctx)
  295. return nil
  296. }
  297. func (c *Client) remoteForwardDone(raddr string, rport int) {
  298. c.remoteForwards.delete(rport, raddr)
  299. }
  300. func (c *Client) LocalForwardClose(ctx context.Context, req LocalForwardReq) {
  301. select {
  302. case c.lfclosec <- req:
  303. case <-ctx.Done():
  304. }
  305. }
  306. func (c *Client) localForwardClose(ctx context.Context, req LocalForwardReq) {
  307. v := c.localForwards.get(req.LocalPort, req.LocalAddr)
  308. if v != nil {
  309. fwd := v.(*forwarder)
  310. fwd.Stop(ctx)
  311. }
  312. }
  313. func (c *Client) RemoteForwardClose(ctx context.Context, req RemoteForwardReq) {
  314. select {
  315. case c.rfclosec <- req:
  316. case <-ctx.Done():
  317. }
  318. }
  319. func (c *Client) remoteForwardClose(ctx context.Context, req RemoteForwardReq) {
  320. v := c.remoteForwards.get(req.RemotePort, req.RemoteAddr)
  321. if v != nil {
  322. fwd := v.(*forwarder)
  323. fwd.Stop(ctx)
  324. }
  325. }