mlkem.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. // Copyright 2024 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. //go:build go1.24
  5. package ssh
  import (
	"crypto"
	"crypto/mlkem"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"runtime"
	"slices"
	"strings"

	"golang.org/x/crypto/curve25519"
)
  17. func init() {
  18. // After Go 1.24rc1 mlkem swapped the order of return values of Encapsulate.
  19. // See #70950.
  20. if runtime.Version() == "go1.24rc1" {
  21. return
  22. }
  23. supportedKexAlgos = slices.Insert(supportedKexAlgos, 0, KeyExchangeMLKEM768X25519)
  24. defaultKexAlgos = slices.Insert(defaultKexAlgos, 0, KeyExchangeMLKEM768X25519)
  25. kexAlgoMap[KeyExchangeMLKEM768X25519] = &mlkem768WithCurve25519sha256{}
  26. }
  // mlkem768WithCurve25519sha256 is the hybrid key exchange combining
// ML-KEM-768 with curve25519-sha256, following
// draft-kampanakis-curdle-ssh-pq-ke-05 section 2.3.3.
//
// The type carries no state. Its Client and Server methods each run one side
// of the exchange.
type mlkem768WithCurve25519sha256 struct{}
  31. func (kex *mlkem768WithCurve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
  32. var c25519kp curve25519KeyPair
  33. if err := c25519kp.generate(rand); err != nil {
  34. return nil, err
  35. }
  36. seed := make([]byte, mlkem.SeedSize)
  37. if _, err := io.ReadFull(rand, seed); err != nil {
  38. return nil, err
  39. }
  40. mlkemDk, err := mlkem.NewDecapsulationKey768(seed)
  41. if err != nil {
  42. return nil, err
  43. }
  44. hybridKey := append(mlkemDk.EncapsulationKey().Bytes(), c25519kp.pub[:]...)
  45. if err := c.writePacket(Marshal(&kexECDHInitMsg{hybridKey})); err != nil {
  46. return nil, err
  47. }
  48. packet, err := c.readPacket()
  49. if err != nil {
  50. return nil, err
  51. }
  52. var reply kexECDHReplyMsg
  53. if err = Unmarshal(packet, &reply); err != nil {
  54. return nil, err
  55. }
  56. if len(reply.EphemeralPubKey) != mlkem.CiphertextSize768+32 {
  57. return nil, errors.New("ssh: peer's mlkem768x25519 public value has wrong length")
  58. }
  59. // Perform KEM decapsulate operation to obtain shared key from ML-KEM.
  60. mlkem768Secret, err := mlkemDk.Decapsulate(reply.EphemeralPubKey[:mlkem.CiphertextSize768])
  61. if err != nil {
  62. return nil, err
  63. }
  64. // Complete Curve25519 ECDH to obtain its shared key.
  65. c25519Secret, err := curve25519.X25519(c25519kp.priv[:], reply.EphemeralPubKey[mlkem.CiphertextSize768:])
  66. if err != nil {
  67. return nil, fmt.Errorf("ssh: peer's mlkem768x25519 public value is not valid: %w", err)
  68. }
  69. // Compute actual shared key.
  70. h := sha256.New()
  71. h.Write(mlkem768Secret)
  72. h.Write(c25519Secret)
  73. secret := h.Sum(nil)
  74. h.Reset()
  75. magics.write(h)
  76. writeString(h, reply.HostKey)
  77. writeString(h, hybridKey)
  78. writeString(h, reply.EphemeralPubKey)
  79. K := make([]byte, stringLength(len(secret)))
  80. marshalString(K, secret)
  81. h.Write(K)
  82. return &kexResult{
  83. H: h.Sum(nil),
  84. K: K,
  85. HostKey: reply.HostKey,
  86. Signature: reply.Signature,
  87. Hash: crypto.SHA256,
  88. }, nil
  89. }
  90. func (kex *mlkem768WithCurve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (*kexResult, error) {
  91. packet, err := c.readPacket()
  92. if err != nil {
  93. return nil, err
  94. }
  95. var kexInit kexECDHInitMsg
  96. if err = Unmarshal(packet, &kexInit); err != nil {
  97. return nil, err
  98. }
  99. if len(kexInit.ClientPubKey) != mlkem.EncapsulationKeySize768+32 {
  100. return nil, errors.New("ssh: peer's ML-KEM768/curve25519 public value has wrong length")
  101. }
  102. encapsulationKey, err := mlkem.NewEncapsulationKey768(kexInit.ClientPubKey[:mlkem.EncapsulationKeySize768])
  103. if err != nil {
  104. return nil, fmt.Errorf("ssh: peer's ML-KEM768 encapsulation key is not valid: %w", err)
  105. }
  106. // Perform KEM encapsulate operation to obtain ciphertext and shared key.
  107. mlkem768Secret, mlkem768Ciphertext := encapsulationKey.Encapsulate()
  108. // Perform server side of Curve25519 ECDH to obtain server public value and
  109. // shared key.
  110. var c25519kp curve25519KeyPair
  111. if err := c25519kp.generate(rand); err != nil {
  112. return nil, err
  113. }
  114. c25519Secret, err := curve25519.X25519(c25519kp.priv[:], kexInit.ClientPubKey[mlkem.EncapsulationKeySize768:])
  115. if err != nil {
  116. return nil, fmt.Errorf("ssh: peer's ML-KEM768/curve25519 public value is not valid: %w", err)
  117. }
  118. hybridKey := append(mlkem768Ciphertext, c25519kp.pub[:]...)
  119. // Compute actual shared key.
  120. h := sha256.New()
  121. h.Write(mlkem768Secret)
  122. h.Write(c25519Secret)
  123. secret := h.Sum(nil)
  124. hostKeyBytes := priv.PublicKey().Marshal()
  125. h.Reset()
  126. magics.write(h)
  127. writeString(h, hostKeyBytes)
  128. writeString(h, kexInit.ClientPubKey)
  129. writeString(h, hybridKey)
  130. K := make([]byte, stringLength(len(secret)))
  131. marshalString(K, secret)
  132. h.Write(K)
  133. H := h.Sum(nil)
  134. sig, err := signAndMarshal(priv, rand, H, algo)
  135. if err != nil {
  136. return nil, err
  137. }
  138. reply := kexECDHReplyMsg{
  139. EphemeralPubKey: hybridKey,
  140. HostKey: hostKeyBytes,
  141. Signature: sig,
  142. }
  143. if err := c.writePacket(Marshal(&reply)); err != nil {
  144. return nil, err
  145. }
  146. return &kexResult{
  147. H: H,
  148. K: K,
  149. HostKey: hostKeyBytes,
  150. Signature: sig,
  151. Hash: crypto.SHA256,
  152. }, nil
  153. }