mse.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
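// Package mse implements the BitTorrent Message Stream Encryption (MSE)
// handshake, as specified at:
// https://wiki.vuze.com/w/Message_Stream_Encryption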
  2. package mse
  3. import (
  4. "bytes"
  5. "context"
  6. "crypto/rand"
  7. "crypto/rc4"
  8. "crypto/sha1"
  9. "encoding/binary"
  10. "errors"
  11. "expvar"
  12. "fmt"
  13. "github.com/anacrolix/torrent/internal/ctxrw"
  14. "io"
  15. "math"
  16. "math/big"
  17. "strconv"
  18. "sync"
  19. "github.com/anacrolix/missinggo/perf"
  20. )
  21. const (
  22. maxPadLen = 512
  23. CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
  24. CryptoMethodRC4 CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
  25. AllSupportedCrypto = CryptoMethodPlaintext | CryptoMethodRC4
  26. )
  27. type CryptoMethod uint32
  28. var (
  29. // Prime P according to the spec, and G, the generator.
  30. p, specG big.Int
  31. // The rand.Int max arg for use in newPadLen()
  32. newPadLenMax big.Int
  33. // For use in initer's hashes
  34. req1 = []byte("req1")
  35. req2 = []byte("req2")
  36. req3 = []byte("req3")
  37. // Verification constant "VC" which is all zeroes in the bittorrent
  38. // implementation.
  39. vc [8]byte
  40. // Zero padding
  41. zeroPad [512]byte
  42. // Tracks counts of received crypto_provides
  43. cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
  44. )
  45. func init() {
  46. p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
  47. specG.SetInt64(2)
  48. newPadLenMax.SetInt64(maxPadLen + 1)
  49. }
  50. func hash(parts ...[]byte) []byte {
  51. h := sha1.New()
  52. for _, p := range parts {
  53. n, err := h.Write(p)
  54. if err != nil {
  55. panic(err)
  56. }
  57. if n != len(p) {
  58. panic(n)
  59. }
  60. }
  61. return h.Sum(nil)
  62. }
  63. func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
  64. c, err := rc4.NewCipher(hash([]byte(func() string {
  65. if initer {
  66. return "keyA"
  67. } else {
  68. return "keyB"
  69. }
  70. }()), s, skey))
  71. if err != nil {
  72. panic(err)
  73. }
  74. var burnSrc, burnDst [1024]byte
  75. c.XORKeyStream(burnDst[:], burnSrc[:])
  76. return
  77. }
  78. type cipherReader struct {
  79. c *rc4.Cipher
  80. r io.Reader
  81. be []byte
  82. }
  83. func (cr *cipherReader) Read(b []byte) (n int, err error) {
  84. if cap(cr.be) < len(b) {
  85. cr.be = make([]byte, len(b))
  86. }
  87. n, err = cr.r.Read(cr.be[:len(b)])
  88. cr.c.XORKeyStream(b[:n], cr.be[:n])
  89. return
  90. }
  91. func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
  92. return &cipherReader{c: c, r: r}
  93. }
  94. type cipherWriter struct {
  95. c *rc4.Cipher
  96. w io.Writer
  97. b []byte
  98. }
  99. func (cr *cipherWriter) Write(b []byte) (n int, err error) {
  100. be := func() []byte {
  101. if len(cr.b) < len(b) {
  102. return make([]byte, len(b))
  103. } else {
  104. ret := cr.b
  105. cr.b = nil
  106. return ret
  107. }
  108. }()
  109. cr.c.XORKeyStream(be, b)
  110. n, err = cr.w.Write(be[:len(b)])
  111. if n != len(b) {
  112. // The cipher will have advanced beyond the callers stream position.
  113. // We can't use the cipher anymore.
  114. cr.c = nil
  115. }
  116. if len(be) > len(cr.b) {
  117. cr.b = be
  118. }
  119. return
  120. }
  121. func newX() big.Int {
  122. var X big.Int
  123. X.SetBytes(func() []byte {
  124. var b [20]byte
  125. _, err := rand.Read(b[:])
  126. if err != nil {
  127. panic(err)
  128. }
  129. return b[:]
  130. }())
  131. return X
  132. }
  133. func paddedLeft(b []byte, _len int) []byte {
  134. if len(b) == _len {
  135. return b
  136. }
  137. ret := make([]byte, _len)
  138. if n := copy(ret[_len-len(b):], b); n != len(b) {
  139. panic(n)
  140. }
  141. return ret
  142. }
  143. // Calculate, and send Y, our public key.
  144. func (h *handshake) postY(x *big.Int) error {
  145. var y big.Int
  146. y.Exp(&specG, x, &p)
  147. return h.postWrite(paddedLeft(y.Bytes(), 96))
  148. }
  149. func (h *handshake) establishS() error {
  150. x := newX()
  151. h.postY(&x)
  152. var b [96]byte
  153. _, err := io.ReadFull(h.ctxConn, b[:])
  154. if err != nil {
  155. return fmt.Errorf("error reading Y: %w", err)
  156. }
  157. var Y, S big.Int
  158. Y.SetBytes(b[:])
  159. S.Exp(&Y, &x, &p)
  160. sBytes := S.Bytes()
  161. copy(h.s[96-len(sBytes):96], sBytes)
  162. return nil
  163. }
  164. func newPadLen() int64 {
  165. i, err := rand.Int(rand.Reader, &newPadLenMax)
  166. if err != nil {
  167. panic(err)
  168. }
  169. ret := i.Int64()
  170. if ret < 0 || ret > maxPadLen {
  171. panic(ret)
  172. }
  173. return ret
  174. }
  175. // Manages state for both initiating and receiving handshakes.
  176. type handshake struct {
  177. conn io.ReadWriter
  178. // The conn with Reads and Writes wrapped to the context given in handshake.Do.
  179. ctxConn io.ReadWriter
  180. s [96]byte
  181. initer bool // Whether we're initiating or receiving.
  182. skeys SecretKeyIter // Skeys we'll accept if receiving.
  183. skey []byte // Skey we're initiating with.
  184. ia []byte // Initial payload. Only used by the initiator.
  185. // Return the bit for the crypto method the receiver wants to use.
  186. chooseMethod CryptoSelector
  187. // Sent to the receiver.
  188. cryptoProvides CryptoMethod
  189. writeMu sync.Mutex
  190. writes [][]byte
  191. writeErr error
  192. writeCond sync.Cond
  193. writeClose bool
  194. writerMu sync.Mutex
  195. writerCond sync.Cond
  196. writerDone bool
  197. }
  198. func (h *handshake) finishWriting() {
  199. h.writeMu.Lock()
  200. h.writeClose = true
  201. h.writeCond.Broadcast()
  202. h.writeMu.Unlock()
  203. h.writerMu.Lock()
  204. for !h.writerDone {
  205. h.writerCond.Wait()
  206. }
  207. h.writerMu.Unlock()
  208. }
  209. func (h *handshake) writer() {
  210. defer func() {
  211. h.writerMu.Lock()
  212. h.writerDone = true
  213. h.writerCond.Broadcast()
  214. h.writerMu.Unlock()
  215. }()
  216. for {
  217. h.writeMu.Lock()
  218. for {
  219. if len(h.writes) != 0 {
  220. break
  221. }
  222. if h.writeClose {
  223. h.writeMu.Unlock()
  224. return
  225. }
  226. h.writeCond.Wait()
  227. }
  228. b := h.writes[0]
  229. h.writes = h.writes[1:]
  230. h.writeMu.Unlock()
  231. _, err := h.ctxConn.Write(b)
  232. if err != nil {
  233. h.writeMu.Lock()
  234. h.writeErr = err
  235. h.writeMu.Unlock()
  236. return
  237. }
  238. }
  239. }
  240. func (h *handshake) postWrite(b []byte) error {
  241. h.writeMu.Lock()
  242. defer h.writeMu.Unlock()
  243. if h.writeErr != nil {
  244. return h.writeErr
  245. }
  246. h.writes = append(h.writes, b)
  247. h.writeCond.Signal()
  248. return nil
  249. }
  250. func xor(a, b []byte) (ret []byte) {
  251. max := len(a)
  252. if max > len(b) {
  253. max = len(b)
  254. }
  255. ret = make([]byte, max)
  256. xorInPlace(ret, a, b)
  257. return
  258. }
  259. func xorInPlace(dst, a, b []byte) {
  260. for i := range dst {
  261. dst[i] = a[i] ^ b[i]
  262. }
  263. }
  264. func marshal(w io.Writer, data ...interface{}) (err error) {
  265. for _, data := range data {
  266. err = binary.Write(w, binary.BigEndian, data)
  267. if err != nil {
  268. break
  269. }
  270. }
  271. return
  272. }
  273. func unmarshal(r io.Reader, data ...interface{}) (err error) {
  274. for _, data := range data {
  275. err = binary.Read(r, binary.BigEndian, data)
  276. if err != nil {
  277. break
  278. }
  279. }
  280. return
  281. }
  282. // Looking for b at the end of a.
  283. func suffixMatchLen(a, b []byte) int {
  284. if len(b) > len(a) {
  285. b = b[:len(a)]
  286. }
  287. // i is how much of b to try to match
  288. for i := len(b); i > 0; i-- {
  289. // j is how many chars we've compared
  290. j := 0
  291. for ; j < i; j++ {
  292. if b[i-1-j] != a[len(a)-1-j] {
  293. goto shorter
  294. }
  295. }
  296. return j
  297. shorter:
  298. }
  299. return 0
  300. }
  301. // Reads from r until b has been seen. Keeps the minimum amount of data in
  302. // memory.
  303. func readUntil(r io.Reader, b []byte) error {
  304. b1 := make([]byte, len(b))
  305. i := 0
  306. for {
  307. _, err := io.ReadFull(r, b1[i:])
  308. if err != nil {
  309. return err
  310. }
  311. i = suffixMatchLen(b1, b)
  312. if i == len(b) {
  313. break
  314. }
  315. if copy(b1, b1[len(b1)-i:]) != i {
  316. panic("wat")
  317. }
  318. }
  319. return nil
  320. }
  321. type readWriter struct {
  322. io.Reader
  323. io.Writer
  324. }
  325. func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
  326. return newEncrypt(initer, h.s[:], h.skey)
  327. }
  328. func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) {
  329. h.postWrite(hash(req1, h.s[:]))
  330. h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
  331. buf := &bytes.Buffer{}
  332. padLen := uint16(newPadLen())
  333. if len(h.ia) > math.MaxUint16 {
  334. err = errors.New("initial payload too large")
  335. return
  336. }
  337. err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
  338. if err != nil {
  339. return
  340. }
  341. e := h.newEncrypt(true)
  342. be := make([]byte, buf.Len())
  343. e.XORKeyStream(be, buf.Bytes())
  344. h.postWrite(be)
  345. bC := h.newEncrypt(false)
  346. var eVC [8]byte
  347. bC.XORKeyStream(eVC[:], vc[:])
  348. // Read until the all zero VC. At this point we've only read the 96 byte
  349. // public key, Y. There is potentially 512 byte padding, between us and
  350. // the 8 byte verification constant.
  351. err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:])
  352. if err != nil {
  353. if err == io.EOF {
  354. err = errors.New("failed to synchronize on VC")
  355. } else {
  356. err = fmt.Errorf("error reading until VC: %w", err)
  357. }
  358. return
  359. }
  360. ctxReader := newCipherReader(bC, h.ctxConn)
  361. var method CryptoMethod
  362. err = unmarshal(ctxReader, &method, &padLen)
  363. if err != nil {
  364. return
  365. }
  366. _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
  367. if err != nil {
  368. return
  369. }
  370. selected = method & h.cryptoProvides
  371. switch selected {
  372. case CryptoMethodRC4:
  373. ret = readWriter{
  374. newCipherReader(bC, h.conn),
  375. &cipherWriter{e, h.conn, nil},
  376. }
  377. case CryptoMethodPlaintext:
  378. ret = h.conn
  379. default:
  380. err = fmt.Errorf("receiver chose unsupported method: %x", method)
  381. }
  382. return
  383. }
  384. var ErrNoSecretKeyMatch = errors.New("no skey matched")
  385. func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) {
  386. // There is up to 512 bytes of padding, then the 20 byte hash.
  387. err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:]))
  388. if err != nil {
  389. if err == io.EOF {
  390. err = errors.New("failed to synchronize on S hash")
  391. }
  392. return
  393. }
  394. var b [20]byte
  395. _, err = io.ReadFull(h.ctxConn, b[:])
  396. if err != nil {
  397. return
  398. }
  399. expectedHash := hash(req3, h.s[:])
  400. eachHash := sha1.New()
  401. var sum, xored [sha1.Size]byte
  402. err = ErrNoSecretKeyMatch
  403. h.skeys(func(skey []byte) bool {
  404. eachHash.Reset()
  405. eachHash.Write(req2)
  406. eachHash.Write(skey)
  407. eachHash.Sum(sum[:0])
  408. xorInPlace(xored[:], sum[:], expectedHash)
  409. if bytes.Equal(xored[:], b[:]) {
  410. h.skey = skey
  411. err = nil
  412. return false
  413. }
  414. return true
  415. })
  416. if err != nil {
  417. return
  418. }
  419. cipher := newEncrypt(true, h.s[:], h.skey)
  420. ctxReader := newCipherReader(cipher, h.ctxConn)
  421. var (
  422. vc [8]byte
  423. provides CryptoMethod
  424. padLen uint16
  425. )
  426. err = unmarshal(ctxReader, vc[:], &provides, &padLen)
  427. if err != nil {
  428. return
  429. }
  430. cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
  431. chosen = h.chooseMethod(provides)
  432. _, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
  433. if err != nil {
  434. return
  435. }
  436. var lenIA uint16
  437. unmarshal(ctxReader, &lenIA)
  438. if lenIA != 0 {
  439. h.ia = make([]byte, lenIA)
  440. unmarshal(ctxReader, h.ia)
  441. }
  442. buf := &bytes.Buffer{}
  443. w := cipherWriter{h.newEncrypt(false), buf, nil}
  444. padLen = uint16(newPadLen())
  445. err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
  446. if err != nil {
  447. return
  448. }
  449. err = h.postWrite(buf.Bytes())
  450. if err != nil {
  451. return
  452. }
  453. switch chosen {
  454. case CryptoMethodRC4:
  455. ret = readWriter{
  456. io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)),
  457. &cipherWriter{w.c, h.conn, nil},
  458. }
  459. case CryptoMethodPlaintext:
  460. ret = readWriter{
  461. io.MultiReader(bytes.NewReader(h.ia), h.conn),
  462. h.conn,
  463. }
  464. default:
  465. err = errors.New("chosen crypto method is not supported")
  466. }
  467. return
  468. }
  469. func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) {
  470. h.writeCond.L = &h.writeMu
  471. h.writerCond.L = &h.writerMu
  472. go h.writer()
  473. defer func() {
  474. h.finishWriting()
  475. if err == nil {
  476. err = h.writeErr
  477. }
  478. }()
  479. err = h.establishS()
  480. if err != nil {
  481. err = fmt.Errorf("error while establishing secret: %w", err)
  482. return
  483. }
  484. pad := make([]byte, newPadLen())
  485. io.ReadFull(rand.Reader, pad)
  486. err = h.postWrite(pad)
  487. if err != nil {
  488. return
  489. }
  490. if h.initer {
  491. ret, method, err = h.initerSteps(ctx)
  492. } else {
  493. ret, method, err = h.receiverSteps(ctx)
  494. }
  495. return
  496. }
  497. func InitiateHandshake(
  498. rw io.ReadWriter,
  499. skey, initialPayload []byte,
  500. cryptoProvides CryptoMethod,
  501. ) (
  502. ret io.ReadWriter, method CryptoMethod, err error,
  503. ) {
  504. return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides)
  505. }
  506. func InitiateHandshakeContext(
  507. ctx context.Context,
  508. rw io.ReadWriter,
  509. skey, initialPayload []byte,
  510. cryptoProvides CryptoMethod,
  511. ) (
  512. ret io.ReadWriter, method CryptoMethod, err error,
  513. ) {
  514. h := handshake{
  515. conn: rw,
  516. ctxConn: ctxrw.WrapReadWriter(ctx, rw),
  517. initer: true,
  518. skey: skey,
  519. ia: initialPayload,
  520. cryptoProvides: cryptoProvides,
  521. }
  522. defer perf.ScopeTimerErr(&err)()
  523. return h.Do(ctx)
  524. }
  525. type HandshakeResult struct {
  526. io.ReadWriter
  527. CryptoMethod
  528. error
  529. SecretKey []byte
  530. }
  531. func ReceiveHandshake(
  532. ctx context.Context,
  533. rw io.ReadWriter,
  534. skeys SecretKeyIter,
  535. selectCrypto CryptoSelector,
  536. ) (io.ReadWriter, CryptoMethod, error) {
  537. res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto)
  538. return res.ReadWriter, res.CryptoMethod, res.error
  539. }
  540. func ReceiveHandshakeEx(
  541. ctx context.Context,
  542. rw io.ReadWriter,
  543. skeys SecretKeyIter,
  544. selectCrypto CryptoSelector,
  545. ) (ret HandshakeResult) {
  546. h := handshake{
  547. conn: rw,
  548. ctxConn: ctxrw.WrapReadWriter(ctx, rw),
  549. initer: false,
  550. skeys: skeys,
  551. chooseMethod: selectCrypto,
  552. }
  553. ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx)
  554. ret.SecretKey = h.skey
  555. return
  556. }
  557. // A function that given a function, calls it with secret keys until it
  558. // returns false or exhausted.
  559. type SecretKeyIter func(callback func(skey []byte) (more bool))
  560. func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
  561. // We prefer plaintext for performance reasons.
  562. if provided&CryptoMethodPlaintext != 0 {
  563. return CryptoMethodPlaintext
  564. }
  565. return CryptoMethodRC4
  566. }
  567. type CryptoSelector func(CryptoMethod) CryptoMethod