s2av2.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. /*
  2. *
  3. * Copyright 2022 Google LLC
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * https://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. // Package v2 provides the S2Av2 transport credentials used by a gRPC
  19. // application.
  20. package v2
  21. import (
  22. "context"
  23. "crypto/tls"
  24. "errors"
  25. "net"
  26. "os"
  27. "time"
  28. "github.com/golang/protobuf/proto"
  29. "github.com/google/s2a-go/fallback"
  30. "github.com/google/s2a-go/internal/handshaker/service"
  31. "github.com/google/s2a-go/internal/tokenmanager"
  32. "github.com/google/s2a-go/internal/v2/tlsconfigstore"
  33. "github.com/google/s2a-go/retry"
  34. "github.com/google/s2a-go/stream"
  35. "google.golang.org/grpc"
  36. "google.golang.org/grpc/credentials"
  37. "google.golang.org/grpc/grpclog"
  38. commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
  39. s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
  40. )
  // Package-level settings shared by the client and server credentials.
const (
	// s2aSecurityProtocol is the SecurityProtocol value placed in the
	// ProtocolInfo of every credentials object created in this file.
	s2aSecurityProtocol = "tls"

	// defaultS2ATimeout is what GetS2ATimeout falls back to when the
	// environment override is absent or unusable.
	defaultS2ATimeout = 6 * time.Second

	// s2aTimeoutEnv names the environment variable, which sets the timeout
	// enforced on the connection to the S2A service for handshake.
	s2aTimeoutEnv = "S2A_TIMEOUT"
)
  47. type s2av2TransportCreds struct {
  48. info *credentials.ProtocolInfo
  49. isClient bool
  50. serverName string
  51. s2av2Address string
  52. transportCreds credentials.TransportCredentials
  53. tokenManager *tokenmanager.AccessTokenManager
  54. // localIdentity should only be used by the client.
  55. localIdentity *commonpbv1.Identity
  56. // localIdentities should only be used by the server.
  57. localIdentities []*commonpbv1.Identity
  58. verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
  59. fallbackClientHandshake fallback.ClientHandshake
  60. getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
  61. serverAuthorizationPolicy []byte
  62. }
  63. // NewClientCreds returns a client-side transport credentials object that uses
  64. // the S2Av2 to establish a secure connection with a server.
  65. func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
  66. // Create an AccessTokenManager instance to use to authenticate to S2Av2.
  67. accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  68. creds := &s2av2TransportCreds{
  69. info: &credentials.ProtocolInfo{
  70. SecurityProtocol: s2aSecurityProtocol,
  71. },
  72. isClient: true,
  73. serverName: "",
  74. s2av2Address: s2av2Address,
  75. transportCreds: transportCreds,
  76. localIdentity: localIdentity,
  77. verificationMode: verificationMode,
  78. fallbackClientHandshake: fallbackClientHandshakeFunc,
  79. getS2AStream: getS2AStream,
  80. serverAuthorizationPolicy: serverAuthorizationPolicy,
  81. }
  82. if err != nil {
  83. creds.tokenManager = nil
  84. } else {
  85. creds.tokenManager = &accessTokenManager
  86. }
  87. if grpclog.V(1) {
  88. grpclog.Info("Created client S2Av2 transport credentials.")
  89. }
  90. return creds, nil
  91. }
  92. // NewServerCreds returns a server-side transport credentials object that uses
  93. // the S2Av2 to establish a secure connection with a client.
  94. func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
  95. // Create an AccessTokenManager instance to use to authenticate to S2Av2.
  96. accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
  97. creds := &s2av2TransportCreds{
  98. info: &credentials.ProtocolInfo{
  99. SecurityProtocol: s2aSecurityProtocol,
  100. },
  101. isClient: false,
  102. s2av2Address: s2av2Address,
  103. transportCreds: transportCreds,
  104. localIdentities: localIdentities,
  105. verificationMode: verificationMode,
  106. getS2AStream: getS2AStream,
  107. }
  108. if err != nil {
  109. creds.tokenManager = nil
  110. } else {
  111. creds.tokenManager = &accessTokenManager
  112. }
  113. if grpclog.V(1) {
  114. grpclog.Info("Created server S2Av2 transport credentials.")
  115. }
  116. return creds, nil
  117. }
  118. // ClientHandshake performs a client-side mTLS handshake using the S2Av2.
  119. func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  120. if !c.isClient {
  121. return nil, nil, errors.New("client handshake called using server transport credentials")
  122. }
  123. // Remove the port from serverAuthority.
  124. serverName := removeServerNamePort(serverAuthority)
  125. timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
  126. defer cancel()
  127. var s2AStream stream.S2AStream
  128. var err error
  129. retry.Run(timeoutCtx,
  130. func() error {
  131. s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
  132. return err
  133. })
  134. if err != nil {
  135. grpclog.Infof("Failed to connect to S2Av2: %v", err)
  136. if c.fallbackClientHandshake != nil {
  137. return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
  138. }
  139. return nil, nil, err
  140. }
  141. defer s2AStream.CloseSend()
  142. if grpclog.V(1) {
  143. grpclog.Infof("Connected to S2Av2.")
  144. }
  145. var config *tls.Config
  146. var tokenManager tokenmanager.AccessTokenManager
  147. if c.tokenManager == nil {
  148. tokenManager = nil
  149. } else {
  150. tokenManager = *c.tokenManager
  151. }
  152. sn := serverName
  153. if c.serverName != "" {
  154. sn = c.serverName
  155. }
  156. retry.Run(timeoutCtx,
  157. func() error {
  158. config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
  159. return err
  160. })
  161. if err != nil {
  162. grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
  163. if c.fallbackClientHandshake != nil {
  164. return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
  165. }
  166. return nil, nil, err
  167. }
  168. if grpclog.V(1) {
  169. grpclog.Infof("Got client TLS config from S2Av2.")
  170. }
  171. creds := credentials.NewTLS(config)
  172. var conn net.Conn
  173. var authInfo credentials.AuthInfo
  174. retry.Run(timeoutCtx,
  175. func() error {
  176. conn, authInfo, err = creds.ClientHandshake(timeoutCtx, serverName, rawConn)
  177. return err
  178. })
  179. if err != nil {
  180. grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
  181. if c.fallbackClientHandshake != nil {
  182. return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
  183. }
  184. return nil, nil, err
  185. }
  186. grpclog.Infof("Successfully done client handshake using S2Av2 to: %s", serverName)
  187. return conn, authInfo, err
  188. }
  189. // ServerHandshake performs a server-side mTLS handshake using the S2Av2.
  190. func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
  191. if c.isClient {
  192. return nil, nil, errors.New("server handshake called using client transport credentials")
  193. }
  194. ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
  195. defer cancel()
  196. var s2AStream stream.S2AStream
  197. var err error
  198. retry.Run(ctx,
  199. func() error {
  200. s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
  201. return err
  202. })
  203. if err != nil {
  204. grpclog.Infof("Failed to connect to S2Av2: %v", err)
  205. return nil, nil, err
  206. }
  207. defer s2AStream.CloseSend()
  208. if grpclog.V(1) {
  209. grpclog.Infof("Connected to S2Av2.")
  210. }
  211. var tokenManager tokenmanager.AccessTokenManager
  212. if c.tokenManager == nil {
  213. tokenManager = nil
  214. } else {
  215. tokenManager = *c.tokenManager
  216. }
  217. var config *tls.Config
  218. retry.Run(ctx,
  219. func() error {
  220. config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
  221. return err
  222. })
  223. if err != nil {
  224. grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
  225. return nil, nil, err
  226. }
  227. if grpclog.V(1) {
  228. grpclog.Infof("Got server TLS config from S2Av2.")
  229. }
  230. creds := credentials.NewTLS(config)
  231. var conn net.Conn
  232. var authInfo credentials.AuthInfo
  233. retry.Run(ctx,
  234. func() error {
  235. conn, authInfo, err = creds.ServerHandshake(rawConn)
  236. return err
  237. })
  238. if err != nil {
  239. grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
  240. return nil, nil, err
  241. }
  242. return conn, authInfo, err
  243. }
  244. // Info returns protocol info of s2av2TransportCreds.
  245. func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
  246. return *c.info
  247. }
  248. // Clone makes a deep copy of s2av2TransportCreds.
  249. func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
  250. info := *c.info
  251. serverName := c.serverName
  252. fallbackClientHandshake := c.fallbackClientHandshake
  253. s2av2Address := c.s2av2Address
  254. var tokenManager tokenmanager.AccessTokenManager
  255. if c.tokenManager == nil {
  256. tokenManager = nil
  257. } else {
  258. tokenManager = *c.tokenManager
  259. }
  260. verificationMode := c.verificationMode
  261. var localIdentity *commonpbv1.Identity
  262. if c.localIdentity != nil {
  263. localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
  264. }
  265. var localIdentities []*commonpbv1.Identity
  266. if c.localIdentities != nil {
  267. localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
  268. for i, localIdentity := range c.localIdentities {
  269. localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
  270. }
  271. }
  272. creds := &s2av2TransportCreds{
  273. info: &info,
  274. isClient: c.isClient,
  275. serverName: serverName,
  276. fallbackClientHandshake: fallbackClientHandshake,
  277. s2av2Address: s2av2Address,
  278. localIdentity: localIdentity,
  279. localIdentities: localIdentities,
  280. verificationMode: verificationMode,
  281. }
  282. if c.tokenManager == nil {
  283. creds.tokenManager = nil
  284. } else {
  285. creds.tokenManager = &tokenManager
  286. }
  287. return creds
  288. }
  289. // NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
  290. // a client. The tls.Config MUST only be used to establish a single TLS connection.
  291. func NewClientTLSConfig(
  292. ctx context.Context,
  293. s2av2Address string,
  294. transportCreds credentials.TransportCredentials,
  295. tokenManager tokenmanager.AccessTokenManager,
  296. verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
  297. serverName string,
  298. serverAuthorizationPolicy []byte) (*tls.Config, error) {
  299. s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
  300. if err != nil {
  301. grpclog.Infof("Failed to connect to S2Av2: %v", err)
  302. return nil, err
  303. }
  304. return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
  305. }
  306. // OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
  307. // info. The ServerName MUST be a hostname.
  308. func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
  309. serverName := removeServerNamePort(serverNameOverride)
  310. c.info.ServerName = serverName
  311. c.serverName = serverName
  312. return nil
  313. }
  // removeServerNamePort strips a trailing port from serverName.
//
// Inputs that net.SplitHostPort rejects (no port, an unbracketed IPv6 literal,
// the empty string, ...) are returned unchanged. Bracketed IPv6 hosts come
// back without brackets, e.g. "[::1]:443" -> "::1".
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
  322. type s2AGrpcStream struct {
  323. stream s2av2pb.S2AService_SetUpSessionClient
  324. }
  325. func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
  326. return x.stream.Send(m)
  327. }
  328. func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
  329. return x.stream.Recv()
  330. }
  331. func (x s2AGrpcStream) CloseSend() error {
  332. return x.stream.CloseSend()
  333. }
  334. func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
  335. if getS2AStream != nil {
  336. return getS2AStream(ctx, s2av2Address)
  337. }
  338. // TODO(rmehta19): Consider whether to close the connection to S2Av2.
  339. conn, err := service.Dial(ctx, s2av2Address, transportCreds)
  340. if err != nil {
  341. return nil, err
  342. }
  343. client := s2av2pb.NewS2AServiceClient(conn)
  344. gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
  345. if err != nil {
  346. return nil, err
  347. }
  348. return &s2AGrpcStream{
  349. stream: gRPCStream,
  350. }, nil
  351. }
  352. // GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
  353. func GetS2ATimeout() time.Duration {
  354. timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
  355. if err != nil {
  356. return defaultS2ATimeout
  357. }
  358. return timeout
  359. }