transport.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "io"
  10. "log"
  11. )
  12. // debugTransport if set, will print packet types as they go over the
  13. // wire. No message decoding is done, to minimize the impact on timing.
  14. const debugTransport = false
  15. // packetConn represents a transport that implements packet based
  16. // operations.
  17. type packetConn interface {
  18. // Encrypt and send a packet of data to the remote peer.
  19. writePacket(packet []byte) error
  20. // Read a packet from the connection. The read is blocking,
  21. // i.e. if error is nil, then the returned byte slice is
  22. // always non-empty.
  23. readPacket() ([]byte, error)
  24. // Close closes the write-side of the connection.
  25. Close() error
  26. }
// transport is the keyingTransport that implements the SSH packet
// protocol.
//
// It layers buffered I/O over the connection passed to newTransport and
// keeps separate cipher and sequence-number state for each direction.
type transport struct {
	// reader holds the cipher state for packets received from the peer.
	reader connectionState
	// writer holds the cipher state for packets sent to the peer.
	writer connectionState
	bufReader *bufio.Reader
	// bufWriter is flushed after every packet by connectionState.writePacket.
	bufWriter *bufio.Writer
	// rand is the randomness source handed to writeCipherPacket.
	rand io.Reader
	// isClient selects the key-derivation directions in newTransport and
	// the "client"/"server" label printed by printPacket.
	isClient bool
	// Closer closes the underlying connection given to newTransport.
	io.Closer
	// strictMode is set by setStrictMode. When true, sequence numbers reset
	// to 0 on every key change, and IGNORE/DEBUG packets are returned to
	// the caller until the initial key exchange is done.
	strictMode bool
	// initialKEXDone is set by setInitialKEXDone once the first key
	// exchange has completed.
	initialKEXDone bool
}
  40. // packetCipher represents a combination of SSH encryption/MAC
  41. // protocol. A single instance should be used for one direction only.
  42. type packetCipher interface {
  43. // writeCipherPacket encrypts the packet and writes it to w. The
  44. // contents of the packet are generally scrambled.
  45. writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error
  46. // readCipherPacket reads and decrypts a packet of data. The
  47. // returned packet may be overwritten by future calls of
  48. // readPacket.
  49. readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
  50. }
// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
type connectionState struct {
	// packetCipher is the cipher currently protecting this direction.
	// newTransport starts it as a streamPacketCipher over noneCipher.
	packetCipher
	// seqNum is the sequence number passed to the cipher for the next
	// packet. It is incremented once per packet (uint32 arithmetic wraps)
	// and, in strict KEX mode, reset to 0 whenever new keys take effect.
	seqNum uint32
	// dir selects the RFC 4253 key-derivation letters for this direction.
	dir direction
	// pendingKeyChange carries the cipher queued by prepareKeyChange. It is
	// installed when a msgNewKeys packet passes through this direction.
	// newTransport gives it a buffer of one.
	pendingKeyChange chan packetCipher
}
  60. func (t *transport) setStrictMode() error {
  61. if t.reader.seqNum != 1 {
  62. return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
  63. }
  64. t.strictMode = true
  65. return nil
  66. }
  67. func (t *transport) setInitialKEXDone() {
  68. t.initialKEXDone = true
  69. }
  70. // prepareKeyChange sets up key material for a keychange. The key changes in
  71. // both directions are triggered by reading and writing a msgNewKey packet
  72. // respectively.
  73. func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error {
  74. ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult)
  75. if err != nil {
  76. return err
  77. }
  78. t.reader.pendingKeyChange <- ciph
  79. ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult)
  80. if err != nil {
  81. return err
  82. }
  83. t.writer.pendingKeyChange <- ciph
  84. return nil
  85. }
  86. func (t *transport) printPacket(p []byte, write bool) {
  87. if len(p) == 0 {
  88. return
  89. }
  90. who := "server"
  91. if t.isClient {
  92. who = "client"
  93. }
  94. what := "read"
  95. if write {
  96. what = "write"
  97. }
  98. log.Println(what, who, p[0])
  99. }
  100. // Read and decrypt next packet.
  101. func (t *transport) readPacket() (p []byte, err error) {
  102. for {
  103. p, err = t.reader.readPacket(t.bufReader, t.strictMode)
  104. if err != nil {
  105. break
  106. }
  107. // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
  108. if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
  109. break
  110. }
  111. }
  112. if debugTransport {
  113. t.printPacket(p, false)
  114. }
  115. return p, err
  116. }
// readPacket reads one packet from r with the current cipher and returns a
// copy of it that the caller may keep.
//
// The sequence number advances on every call, including failed reads.
// Two message types are interpreted here, below every other layer:
//   - msgNewKeys installs the cipher queued on pendingKeyChange (and, in
//     strict mode, resets the sequence number to 0). If no cipher is
//     queued, the message is rejected as bogus. On success the msgNewKeys
//     packet itself is still returned.
//   - msgDisconnect is decoded and returned as a *disconnectMsg error.
//
// A read that succeeds but yields an empty packet is reported as an error.
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
	// Increment before the key-change handling below, so that a strict-mode
	// reset leaves seqNum at 0 for the first packet under the new keys.
	s.seqNum++
	if err == nil && len(packet) == 0 {
		err = errors.New("ssh: zero length packet")
	}
	if len(packet) > 0 {
		switch packet[0] {
		case msgNewKeys:
			select {
			case cipher := <-s.pendingKeyChange:
				s.packetCipher = cipher
				if strictMode {
					s.seqNum = 0
				}
			default:
				// Non-blocking receive: a peer that sends NEWKEYS before
				// prepareKeyChange queued a cipher is a protocol error.
				return nil, errors.New("ssh: got bogus newkeys message")
			}
		case msgDisconnect:
			// Transform a disconnect message into an
			// error. Since this is lowest level at which
			// we interpret message types, doing it here
			// ensures that we don't have to handle it
			// elsewhere.
			var msg disconnectMsg
			if err := Unmarshal(packet, &msg); err != nil {
				return nil, err
			}
			return nil, &msg
		}
	}
	// The packet may point to an internal buffer, so copy the
	// packet out here.
	fresh := make([]byte, len(packet))
	copy(fresh, packet)
	return fresh, err
}
  154. func (t *transport) writePacket(packet []byte) error {
  155. if debugTransport {
  156. t.printPacket(packet, true)
  157. }
  158. return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
  159. }
  160. func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
  161. changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
  162. err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
  163. if err != nil {
  164. return err
  165. }
  166. if err = w.Flush(); err != nil {
  167. return err
  168. }
  169. s.seqNum++
  170. if changeKeys {
  171. select {
  172. case cipher := <-s.pendingKeyChange:
  173. s.packetCipher = cipher
  174. if strictMode {
  175. s.seqNum = 0
  176. }
  177. default:
  178. panic("ssh: no key material for msgNewKeys")
  179. }
  180. }
  181. return err
  182. }
  183. func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
  184. t := &transport{
  185. bufReader: bufio.NewReader(rwc),
  186. bufWriter: bufio.NewWriter(rwc),
  187. rand: rand,
  188. reader: connectionState{
  189. packetCipher: &streamPacketCipher{cipher: noneCipher{}},
  190. pendingKeyChange: make(chan packetCipher, 1),
  191. },
  192. writer: connectionState{
  193. packetCipher: &streamPacketCipher{cipher: noneCipher{}},
  194. pendingKeyChange: make(chan packetCipher, 1),
  195. },
  196. Closer: rwc,
  197. }
  198. t.isClient = isClient
  199. if isClient {
  200. t.reader.dir = serverKeys
  201. t.writer.dir = clientKeys
  202. } else {
  203. t.reader.dir = clientKeys
  204. t.writer.dir = serverKeys
  205. }
  206. return t
  207. }
  208. type direction struct {
  209. ivTag []byte
  210. keyTag []byte
  211. macKeyTag []byte
  212. }
  213. var (
  214. serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
  215. clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
  216. )
  217. // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
  218. // described in RFC 4253, section 6.4. direction should either be serverKeys
  219. // (to setup server->client keys) or clientKeys (for client->server keys).
  220. func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) {
  221. cipherMode := cipherModes[algs.Cipher]
  222. iv := make([]byte, cipherMode.ivSize)
  223. key := make([]byte, cipherMode.keySize)
  224. generateKeyMaterial(iv, d.ivTag, kex)
  225. generateKeyMaterial(key, d.keyTag, kex)
  226. var macKey []byte
  227. if !aeadCiphers[algs.Cipher] {
  228. macMode := macModes[algs.MAC]
  229. macKey = make([]byte, macMode.keySize)
  230. generateKeyMaterial(macKey, d.macKeyTag, kex)
  231. }
  232. return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
  233. }
  234. // generateKeyMaterial fills out with key material generated from tag, K, H
  235. // and sessionId, as specified in RFC 4253, section 7.2.
  236. func generateKeyMaterial(out, tag []byte, r *kexResult) {
  237. var digestsSoFar []byte
  238. h := r.Hash.New()
  239. for len(out) > 0 {
  240. h.Reset()
  241. h.Write(r.K)
  242. h.Write(r.H)
  243. if len(digestsSoFar) == 0 {
  244. h.Write(tag)
  245. h.Write(r.SessionID)
  246. } else {
  247. h.Write(digestsSoFar)
  248. }
  249. digest := h.Sum(nil)
  250. n := copy(out, digest)
  251. out = out[n:]
  252. if len(out) > 0 {
  253. digestsSoFar = append(digestsSoFar, digest...)
  254. }
  255. }
  256. }
  257. const packageVersion = "SSH-2.0-Go"
  258. // Sends and receives a version line. The versionLine string should
  259. // be US ASCII, start with "SSH-2.0-", and should not include a
  260. // newline. exchangeVersions returns the other side's version line.
  261. func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
  262. // Contrary to the RFC, we do not ignore lines that don't
  263. // start with "SSH-2.0-" to make the library usable with
  264. // nonconforming servers.
  265. for _, c := range versionLine {
  266. // The spec disallows non US-ASCII chars, and
  267. // specifically forbids null chars.
  268. if c < 32 {
  269. return nil, errors.New("ssh: junk character in version line")
  270. }
  271. }
  272. if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
  273. return
  274. }
  275. them, err = readVersion(rw)
  276. return them, err
  277. }
  278. // maxVersionStringBytes is the maximum number of bytes that we'll
  279. // accept as a version string. RFC 4253 section 4.2 limits this at 255
  280. // chars
  281. const maxVersionStringBytes = 255
  282. // Read version string as specified by RFC 4253, section 4.2.
  283. func readVersion(r io.Reader) ([]byte, error) {
  284. versionString := make([]byte, 0, 64)
  285. var ok bool
  286. var buf [1]byte
  287. for length := 0; length < maxVersionStringBytes; length++ {
  288. _, err := io.ReadFull(r, buf[:])
  289. if err != nil {
  290. return nil, err
  291. }
  292. // The RFC says that the version should be terminated with \r\n
  293. // but several SSH servers actually only send a \n.
  294. if buf[0] == '\n' {
  295. if !bytes.HasPrefix(versionString, []byte("SSH-")) {
  296. // RFC 4253 says we need to ignore all version string lines
  297. // except the one containing the SSH version (provided that
  298. // all the lines do not exceed 255 bytes in total).
  299. versionString = versionString[:0]
  300. continue
  301. }
  302. ok = true
  303. break
  304. }
  305. // non ASCII chars are disallowed, but we are lenient,
  306. // since Go doesn't use null-terminated strings.
  307. // The RFC allows a comment after a space, however,
  308. // all of it (version and comments) goes into the
  309. // session hash.
  310. versionString = append(versionString, buf[0])
  311. }
  312. if !ok {
  313. return nil, errors.New("ssh: overflow reading version string")
  314. }
  315. // There might be a '\r' on the end which we should remove.
  316. if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' {
  317. versionString = versionString[:len(versionString)-1]
  318. }
  319. return versionString, nil
  320. }