| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- // Copyright 2019 Yunion
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package ssh
import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net"
	"os"
	"os/signal"
	"strconv"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"

	"yunion.io/x/log"
	"yunion.io/x/pkg/errors"
)
- const (
- ErrBadConfig = errors.Error("bad config")
- ErrNetwork = errors.Error("network error")
- ErrProtocol = errors.Error("ssh protocol error")
- )
// ClientConfig holds the parameters needed to open an SSH connection.
// Password and PrivateKey are each optional; every non-empty one is offered
// to the server as an authentication method (see ToSshConfig).
type ClientConfig struct {
	// Username is the remote login user.
	Username string
	// Password enables password authentication when non-empty.
	Password string
	// Host is the remote host name or IP address.
	Host string
	// Port is the remote TCP port.
	Port int
	// PrivateKey, when non-empty, is parsed with ssh.ParsePrivateKey and
	// enables public-key authentication.
	PrivateKey string
}
- func parsePrivateKey(keyBuff string) (ssh.Signer, error) {
- return ssh.ParsePrivateKey([]byte(keyBuff))
- }
- func (conf ClientConfig) ToSshConfig() (*ssh.ClientConfig, error) {
- cliConfig := &ssh.ClientConfig{
- User: conf.Username,
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- Timeout: 15 * time.Second,
- }
- auths := make([]ssh.AuthMethod, 0)
- if conf.Password != "" {
- auths = append(auths, ssh.Password(conf.Password))
- }
- if conf.PrivateKey != "" {
- signer, err := parsePrivateKey(conf.PrivateKey)
- if err != nil {
- return nil, errors.Wrapf(ErrBadConfig, "parse private key: %v", err)
- }
- auths = append(auths, ssh.PublicKeys(signer))
- }
- cliConfig.Auth = auths
- return cliConfig, nil
- }
- func (conf ClientConfig) Connect() (*ssh.Client, error) {
- cliConfig, err := conf.ToSshConfig()
- if err != nil {
- return nil, err
- }
- addr := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
- client, err := ssh.Dial("tcp", addr, cliConfig)
- if err != nil {
- return nil, err
- }
- return client, nil
- }
- func (conf ClientConfig) ConnectContext(ctx context.Context) (*ssh.Client, error) {
- cliConfig, err := conf.ToSshConfig()
- if err != nil {
- return nil, err
- }
- addr := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port))
- d := &net.Dialer{}
- netconn, err := d.DialContext(ctx, "tcp", addr)
- if err != nil {
- return nil, errors.Wrapf(ErrNetwork, "tcp dial: %v", err)
- }
- sshconn, chans, reqs, err := ssh.NewClientConn(netconn, addr, cliConfig)
- if err != nil {
- netconn.Close()
- return nil, errors.Wrap(ErrProtocol, err.Error())
- }
- sshc := ssh.NewClient(sshconn, chans, reqs)
- return sshc, nil
- }
- type Client struct {
- config ClientConfig
- client *ssh.Client
- }
- func (conf ClientConfig) NewClient() (*Client, error) {
- cli, err := conf.Connect()
- if err != nil {
- return nil, err
- }
- return &Client{
- config: conf,
- client: cli,
- }, nil
- }
- func NewClient(
- host string,
- port int,
- username string,
- password string,
- privateKey string,
- ) (*Client, error) {
- config := &ClientConfig{
- Host: host,
- Port: port,
- Username: username,
- Password: password,
- PrivateKey: privateKey,
- }
- return config.NewClient()
- }
- func (s *Client) GetConfig() ClientConfig {
- return s.config
- }
- func (s *Client) RawRun(cmds ...string) ([]string, error) {
- return s.run(false, cmds, nil, false)
- }
- func (s *Client) RunCmd(cmd string) ([]string, error) {
- return s.Run(cmd)
- }
- func (s *Client) Run(cmds ...string) ([]string, error) {
- return s.run(true, cmds, nil, false)
- }
- func (s *Client) RunWithInput(input io.Reader, cmds ...string) ([]string, error) {
- return s.run(true, cmds, input, false)
- }
- // RunWithTTY request Pty before run command.
- func (s *Client) RunWithTTY(cmds ...string) ([]string, error) {
- return s.run(false, cmds, nil, true)
- }
- func (s *Client) run(parseOutput bool, cmds []string, input io.Reader, withPty bool) ([]string, error) {
- ret := []string{}
- for _, cmd := range cmds {
- session, err := s.client.NewSession()
- if err != nil {
- return nil, err
- }
- defer session.Close()
- if withPty {
- modes := ssh.TerminalModes{
- ssh.ECHO: 1, // enable echoing
- ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
- ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
- }
- if err := session.RequestPty("xterm", 24, 80, modes); err != nil {
- return nil, errors.Wrap(err, "Setup TTY")
- }
- }
- log.Debugf("Run command(%s@%s): %s", s.config.Username, s.config.Host, cmd)
- var stdOut bytes.Buffer
- var stdErr bytes.Buffer
- session.Stdout = &stdOut
- session.Stderr = &stdErr
- session.Stdin = input
- err = session.Run(cmd)
- if err != nil {
- var outputErr error
- errMsg := stdErr.String()
- if len(stdOut.String()) != 0 {
- errMsg = fmt.Sprintf("%s %s", errMsg, stdOut.String())
- }
- outputErr = errors.Error(errMsg)
- err = errors.Wrapf(outputErr, "%q error: %v, cmd error", cmd, err)
- return nil, err
- }
- if parseOutput {
- ret = append(ret, ParseOutput(stdOut.Bytes())...)
- } else {
- ret = append(ret, stdOut.String())
- }
- }
- return ret, nil
- }
// ParseOutput splits output on "\n" and trims surrounding whitespace from each
// line. Empty lines are kept, including the final empty element produced by a
// trailing newline, so the result always has at least one element.
func ParseOutput(output []byte) []string {
	rawLines := strings.Split(string(output), "\n")
	trimmed := make([]string, len(rawLines))
	for i, l := range rawLines {
		trimmed[i] = strings.TrimSpace(l)
	}
	return trimmed
}
- func (s *Client) Close() {
- s.client.Close()
- }
- func updateTermSize(session *ssh.Session, quit <-chan int) {
- sigwinchCh := make(chan os.Signal, 1)
- setsignal(sigwinchCh)
- fd := int(os.Stdin.Fd())
- width, height, err := terminal.GetSize(fd)
- if err != nil {
- log.Errorf("get terminal size: %v", err)
- }
- for {
- select {
- case <-quit:
- return
- case sigwinCh := <-sigwinchCh:
- if sigwinCh == nil {
- <-quit
- return
- }
- termWidth, termHeight, err := terminal.GetSize(fd)
- if err != nil {
- log.Errorf("get terminal size: %v", err)
- }
- if termHeight == height && termWidth == width {
- continue
- }
- err = session.WindowChange(termHeight, termWidth)
- if err != nil {
- log.Errorf("send window-change request: %v", err)
- continue
- }
- width = termWidth
- height = termHeight
- }
- }
- }
- func (s *Client) RunTerminal() error {
- defer s.Close()
- session, err := s.client.NewSession()
- if err != nil {
- return errors.Wrap(err, "open new session")
- }
- defer session.Close()
- fd := int(os.Stdin.Fd())
- state, err := terminal.MakeRaw(fd)
- if err != nil {
- return errors.Wrap(err, "make raw terminal")
- }
- defer terminal.Restore(fd, state)
- w, h, err := terminal.GetSize(fd)
- if err != nil {
- return errors.Wrap(err, "get terminal size")
- }
- modes := ssh.TerminalModes{
- ssh.ECHO: 1,
- ssh.TTY_OP_ISPEED: 14400,
- ssh.TTY_OP_OSPEED: 14400,
- }
- term := os.Getenv("TERM")
- if term == "" {
- term = "xterm-256color"
- }
- if err := session.RequestPty(term, h, w, modes); err != nil {
- return errors.Wrap(err, "session xterm")
- }
- session.Stdout = os.Stdout
- session.Stderr = os.Stderr
- session.Stdin = os.Stdin
- if err := session.Shell(); err != nil {
- return errors.Wrap(err, "session shell")
- }
- quit := make(chan int)
- go updateTermSize(session, quit)
- if err := session.Wait(); err != nil {
- if e, ok := err.(*ssh.ExitError); ok {
- switch e.ExitStatus() {
- case 130:
- quit <- 1
- return nil
- }
- }
- quit <- 1
- return errors.Wrap(err, "ssh wait")
- }
- quit <- 1
- return nil
- }
- func IsExitMissingError(err error) bool {
- errStr := new(ssh.ExitMissingError).Error()
- if strings.Contains(err.Error(), errStr) {
- return true
- }
- return false
- }
|