s2a.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. /*
  2. *
  3. * Copyright 2021 Google LLC
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * https://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package s2a provides the S2A transport credentials used by a gRPC
  19. // application.
  20. package s2a
  21. import (
  22. "context"
  23. "crypto/tls"
  24. "errors"
  25. "fmt"
  26. "net"
  27. "sync"
  28. "time"
  29. "github.com/golang/protobuf/proto"
  30. "github.com/google/s2a-go/fallback"
  31. "github.com/google/s2a-go/internal/handshaker"
  32. "github.com/google/s2a-go/internal/handshaker/service"
  33. "github.com/google/s2a-go/internal/tokenmanager"
  34. "github.com/google/s2a-go/internal/v2"
  35. "github.com/google/s2a-go/retry"
  36. "google.golang.org/grpc/credentials"
  37. "google.golang.org/grpc/grpclog"
  38. commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
  39. s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  40. )
  41. const (
  42. s2aSecurityProtocol = "tls"
  43. // defaultTimeout specifies the default server handshake timeout.
  44. defaultTimeout = 30.0 * time.Second
  45. )
  46. // s2aTransportCreds are the transport credentials required for establishing
  47. // a secure connection using the S2A. They implement the
  48. // credentials.TransportCredentials interface.
  49. type s2aTransportCreds struct {
  50. info *credentials.ProtocolInfo
  51. minTLSVersion commonpb.TLSVersion
  52. maxTLSVersion commonpb.TLSVersion
  53. // tlsCiphersuites contains the ciphersuites used in the S2A connection.
  54. // Note that these are currently unconfigurable.
  55. tlsCiphersuites []commonpb.Ciphersuite
  56. // localIdentity should only be used by the client.
  57. localIdentity *commonpb.Identity
  58. // localIdentities should only be used by the server.
  59. localIdentities []*commonpb.Identity
  60. // targetIdentities should only be used by the client.
  61. targetIdentities []*commonpb.Identity
  62. isClient bool
  63. s2aAddr string
  64. ensureProcessSessionTickets *sync.WaitGroup
  65. }
  66. // NewClientCreds returns a client-side transport credentials object that uses
  67. // the S2A to establish a secure connection with a server.
  68. func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
  69. if opts == nil {
  70. return nil, errors.New("nil client options")
  71. }
  72. var targetIdentities []*commonpb.Identity
  73. for _, targetIdentity := range opts.TargetIdentities {
  74. protoTargetIdentity, err := toProtoIdentity(targetIdentity)
  75. if err != nil {
  76. return nil, err
  77. }
  78. targetIdentities = append(targetIdentities, protoTargetIdentity)
  79. }
  80. localIdentity, err := toProtoIdentity(opts.LocalIdentity)
  81. if err != nil {
  82. return nil, err
  83. }
  84. if opts.EnableLegacyMode {
  85. return &s2aTransportCreds{
  86. info: &credentials.ProtocolInfo{
  87. SecurityProtocol: s2aSecurityProtocol,
  88. },
  89. minTLSVersion: commonpb.TLSVersion_TLS1_3,
  90. maxTLSVersion: commonpb.TLSVersion_TLS1_3,
  91. tlsCiphersuites: []commonpb.Ciphersuite{
  92. commonpb.Ciphersuite_AES_128_GCM_SHA256,
  93. commonpb.Ciphersuite_AES_256_GCM_SHA384,
  94. commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
  95. },
  96. localIdentity: localIdentity,
  97. targetIdentities: targetIdentities,
  98. isClient: true,
  99. s2aAddr: opts.S2AAddress,
  100. ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
  101. }, nil
  102. }
  103. verificationMode := getVerificationMode(opts.VerificationMode)
  104. var fallbackFunc fallback.ClientHandshake
  105. if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
  106. fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
  107. }
  108. return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, localIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
  109. }
  110. // NewServerCreds returns a server-side transport credentials object that uses
  111. // the S2A to establish a secure connection with a client.
  112. func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
  113. if opts == nil {
  114. return nil, errors.New("nil server options")
  115. }
  116. var localIdentities []*commonpb.Identity
  117. for _, localIdentity := range opts.LocalIdentities {
  118. protoLocalIdentity, err := toProtoIdentity(localIdentity)
  119. if err != nil {
  120. return nil, err
  121. }
  122. localIdentities = append(localIdentities, protoLocalIdentity)
  123. }
  124. if opts.EnableLegacyMode {
  125. return &s2aTransportCreds{
  126. info: &credentials.ProtocolInfo{
  127. SecurityProtocol: s2aSecurityProtocol,
  128. },
  129. minTLSVersion: commonpb.TLSVersion_TLS1_3,
  130. maxTLSVersion: commonpb.TLSVersion_TLS1_3,
  131. tlsCiphersuites: []commonpb.Ciphersuite{
  132. commonpb.Ciphersuite_AES_128_GCM_SHA256,
  133. commonpb.Ciphersuite_AES_256_GCM_SHA384,
  134. commonpb.Ciphersuite_CHACHA20_POLY1305_SHA256,
  135. },
  136. localIdentities: localIdentities,
  137. isClient: false,
  138. s2aAddr: opts.S2AAddress,
  139. }, nil
  140. }
  141. verificationMode := getVerificationMode(opts.VerificationMode)
  142. return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, localIdentities, verificationMode, opts.getS2AStream)
  143. }
  144. // ClientHandshake initiates a client-side TLS handshake using the S2A.
  145. func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  146. if !c.isClient {
  147. return nil, nil, errors.New("client handshake called using server transport credentials")
  148. }
  149. var cancel context.CancelFunc
  150. ctx, cancel = context.WithCancel(ctx)
  151. defer cancel()
  152. // Connect to the S2A.
  153. hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
  154. if err != nil {
  155. grpclog.Infof("Failed to connect to S2A: %v", err)
  156. return nil, nil, err
  157. }
  158. opts := &handshaker.ClientHandshakerOptions{
  159. MinTLSVersion: c.minTLSVersion,
  160. MaxTLSVersion: c.maxTLSVersion,
  161. TLSCiphersuites: c.tlsCiphersuites,
  162. TargetIdentities: c.targetIdentities,
  163. LocalIdentity: c.localIdentity,
  164. TargetName: serverAuthority,
  165. EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
  166. }
  167. chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
  168. if err != nil {
  169. grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
  170. return nil, nil, err
  171. }
  172. defer func() {
  173. if err != nil {
  174. if closeErr := chs.Close(); closeErr != nil {
  175. grpclog.Infof("Close failed unexpectedly: %v", err)
  176. err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
  177. }
  178. }
  179. }()
  180. secConn, authInfo, err := chs.ClientHandshake(context.Background())
  181. if err != nil {
  182. grpclog.Infof("Handshake failed: %v", err)
  183. return nil, nil, err
  184. }
  185. return secConn, authInfo, nil
  186. }
  187. // ServerHandshake initiates a server-side TLS handshake using the S2A.
  188. func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  189. if c.isClient {
  190. return nil, nil, errors.New("server handshake called using client transport credentials")
  191. }
  192. ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
  193. defer cancel()
  194. // Connect to the S2A.
  195. hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
  196. if err != nil {
  197. grpclog.Infof("Failed to connect to S2A: %v", err)
  198. return nil, nil, err
  199. }
  200. opts := &handshaker.ServerHandshakerOptions{
  201. MinTLSVersion: c.minTLSVersion,
  202. MaxTLSVersion: c.maxTLSVersion,
  203. TLSCiphersuites: c.tlsCiphersuites,
  204. LocalIdentities: c.localIdentities,
  205. }
  206. shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
  207. if err != nil {
  208. grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
  209. return nil, nil, err
  210. }
  211. defer func() {
  212. if err != nil {
  213. if closeErr := shs.Close(); closeErr != nil {
  214. grpclog.Infof("Close failed unexpectedly: %v", err)
  215. err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
  216. }
  217. }
  218. }()
  219. secConn, authInfo, err := shs.ServerHandshake(context.Background())
  220. if err != nil {
  221. grpclog.Infof("Handshake failed: %v", err)
  222. return nil, nil, err
  223. }
  224. return secConn, authInfo, nil
  225. }
  226. func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
  227. return *c.info
  228. }
  229. func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
  230. info := *c.info
  231. var localIdentity *commonpb.Identity
  232. if c.localIdentity != nil {
  233. localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
  234. }
  235. var localIdentities []*commonpb.Identity
  236. if c.localIdentities != nil {
  237. localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
  238. for i, localIdentity := range c.localIdentities {
  239. localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
  240. }
  241. }
  242. var targetIdentities []*commonpb.Identity
  243. if c.targetIdentities != nil {
  244. targetIdentities = make([]*commonpb.Identity, len(c.targetIdentities))
  245. for i, targetIdentity := range c.targetIdentities {
  246. targetIdentities[i] = proto.Clone(targetIdentity).(*commonpb.Identity)
  247. }
  248. }
  249. return &s2aTransportCreds{
  250. info: &info,
  251. minTLSVersion: c.minTLSVersion,
  252. maxTLSVersion: c.maxTLSVersion,
  253. tlsCiphersuites: c.tlsCiphersuites,
  254. localIdentity: localIdentity,
  255. localIdentities: localIdentities,
  256. targetIdentities: targetIdentities,
  257. isClient: c.isClient,
  258. s2aAddr: c.s2aAddr,
  259. }
  260. }
  261. func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
  262. c.info.ServerName = serverNameOverride
  263. return nil
  264. }
  265. // TLSClientConfigOptions specifies parameters for creating client TLS config.
  266. type TLSClientConfigOptions struct {
  267. // ServerName is required by s2a as the expected name when verifying the hostname found in server's certificate.
  268. // tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
  269. // ServerName: "example.com",
  270. // })
  271. ServerName string
  272. }
  273. // TLSClientConfigFactory defines the interface for a client TLS config factory.
  274. type TLSClientConfigFactory interface {
  275. Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
  276. }
  277. // NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
  278. func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
  279. if opts == nil {
  280. return nil, fmt.Errorf("opts must be non-nil")
  281. }
  282. if opts.EnableLegacyMode {
  283. return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
  284. }
  285. tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  286. if err != nil {
  287. // The only possible error is: access token not set in the environment,
  288. // which is okay in environments other than serverless.
  289. grpclog.Infof("Access token manager not initialized: %v", err)
  290. return &s2aTLSClientConfigFactory{
  291. s2av2Address: opts.S2AAddress,
  292. transportCreds: opts.TransportCreds,
  293. tokenManager: nil,
  294. verificationMode: getVerificationMode(opts.VerificationMode),
  295. serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
  296. }, nil
  297. }
  298. return &s2aTLSClientConfigFactory{
  299. s2av2Address: opts.S2AAddress,
  300. transportCreds: opts.TransportCreds,
  301. tokenManager: tokenManager,
  302. verificationMode: getVerificationMode(opts.VerificationMode),
  303. serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
  304. }, nil
  305. }
  306. type s2aTLSClientConfigFactory struct {
  307. s2av2Address string
  308. transportCreds credentials.TransportCredentials
  309. tokenManager tokenmanager.AccessTokenManager
  310. verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
  311. serverAuthorizationPolicy []byte
  312. }
  313. func (f *s2aTLSClientConfigFactory) Build(
  314. ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
  315. serverName := ""
  316. if opts != nil && opts.ServerName != "" {
  317. serverName = opts.ServerName
  318. }
  319. return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
  320. }
  321. func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
  322. switch verificationMode {
  323. case ConnectToGoogle:
  324. return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
  325. case Spiffe:
  326. return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
  327. default:
  328. return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
  329. }
  330. }
  331. // NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
  332. // Example use with http.RoundTripper:
  333. //
  334. // dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
  335. // S2AAddress: s2aAddress, // required
  336. // })
  337. // transport := http.DefaultTransport
  338. // transport.DialTLSContext = dialTLSContext
  339. func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
  340. return func(ctx context.Context, network, addr string) (net.Conn, error) {
  341. fallback := func(err error) (net.Conn, error) {
  342. if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
  343. opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
  344. fbDialer := opts.FallbackOpts.FallbackDialer
  345. grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
  346. fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
  347. if fbErr != nil {
  348. return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
  349. }
  350. return fbConn, nil
  351. }
  352. return nil, err
  353. }
  354. factory, err := NewTLSClientConfigFactory(opts)
  355. if err != nil {
  356. grpclog.Infof("error creating S2A client config factory: %v", err)
  357. return fallback(err)
  358. }
  359. serverName, _, err := net.SplitHostPort(addr)
  360. if err != nil {
  361. serverName = addr
  362. }
  363. timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
  364. defer cancel()
  365. var s2aTLSConfig *tls.Config
  366. retry.Run(timeoutCtx,
  367. func() error {
  368. s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
  369. ServerName: serverName,
  370. })
  371. return err
  372. })
  373. if err != nil {
  374. grpclog.Infof("error building S2A TLS config: %v", err)
  375. return fallback(err)
  376. }
  377. s2aDialer := &tls.Dialer{
  378. Config: s2aTLSConfig,
  379. }
  380. var c net.Conn
  381. retry.Run(timeoutCtx,
  382. func() error {
  383. c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
  384. return err
  385. })
  386. if err != nil {
  387. grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
  388. return fallback(err)
  389. }
  390. grpclog.Infof("success dialing MTLS to %s with S2A", addr)
  391. return c, nil
  392. }
  393. }