mirror of
https://github.com/rocky-linux/peridot.git
synced 2024-11-24 06:01:25 +00:00
392 lines
13 KiB
Go
392 lines
13 KiB
Go
|
/*
|
||
|
*
|
||
|
* Copyright 2022 Google LLC
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*
|
||
|
*/
|
||
|
|
||
|
// Package v2 provides the S2Av2 transport credentials used by a gRPC
|
||
|
// application.
|
||
|
package v2
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"errors"
|
||
|
"net"
|
||
|
"os"
|
||
|
"time"
|
||
|
|
||
|
"github.com/golang/protobuf/proto"
|
||
|
"github.com/google/s2a-go/fallback"
|
||
|
"github.com/google/s2a-go/internal/handshaker/service"
|
||
|
"github.com/google/s2a-go/internal/tokenmanager"
|
||
|
"github.com/google/s2a-go/internal/v2/tlsconfigstore"
|
||
|
"github.com/google/s2a-go/retry"
|
||
|
"github.com/google/s2a-go/stream"
|
||
|
"google.golang.org/grpc"
|
||
|
"google.golang.org/grpc/credentials"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
|
||
|
commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
|
||
|
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
|
||
|
)
|
||
|
|
||
|
// Defaults shared by the client-side and server-side S2Av2 transport
// credentials.
const (
	// s2aSecurityProtocol is reported as the SecurityProtocol of the
	// credentials' ProtocolInfo.
	s2aSecurityProtocol = "tls"
	// defaultS2ATimeout is the timeout returned by GetS2ATimeout when the
	// S2A_TIMEOUT environment variable does not hold a parsable duration.
	defaultS2ATimeout = 6 * time.Second
)
|
||
|
|
||
|
// s2aTimeoutEnv is the name of the environment variable that sets the timeout
// enforced on the connection to the S2A service during the handshake.
const s2aTimeoutEnv = "S2A_TIMEOUT"
|
||
|
|
||
|
type s2av2TransportCreds struct {
|
||
|
info *credentials.ProtocolInfo
|
||
|
isClient bool
|
||
|
serverName string
|
||
|
s2av2Address string
|
||
|
transportCreds credentials.TransportCredentials
|
||
|
tokenManager *tokenmanager.AccessTokenManager
|
||
|
// localIdentity should only be used by the client.
|
||
|
localIdentity *commonpbv1.Identity
|
||
|
// localIdentities should only be used by the server.
|
||
|
localIdentities []*commonpbv1.Identity
|
||
|
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
|
||
|
fallbackClientHandshake fallback.ClientHandshake
|
||
|
getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
|
||
|
serverAuthorizationPolicy []byte
|
||
|
}
|
||
|
|
||
|
// NewClientCreds returns a client-side transport credentials object that uses
|
||
|
// the S2Av2 to establish a secure connection with a server.
|
||
|
func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
|
||
|
// Create an AccessTokenManager instance to use to authenticate to S2Av2.
|
||
|
accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
|
||
|
|
||
|
creds := &s2av2TransportCreds{
|
||
|
info: &credentials.ProtocolInfo{
|
||
|
SecurityProtocol: s2aSecurityProtocol,
|
||
|
},
|
||
|
isClient: true,
|
||
|
serverName: "",
|
||
|
s2av2Address: s2av2Address,
|
||
|
transportCreds: transportCreds,
|
||
|
localIdentity: localIdentity,
|
||
|
verificationMode: verificationMode,
|
||
|
fallbackClientHandshake: fallbackClientHandshakeFunc,
|
||
|
getS2AStream: getS2AStream,
|
||
|
serverAuthorizationPolicy: serverAuthorizationPolicy,
|
||
|
}
|
||
|
if err != nil {
|
||
|
creds.tokenManager = nil
|
||
|
} else {
|
||
|
creds.tokenManager = &accessTokenManager
|
||
|
}
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Info("Created client S2Av2 transport credentials.")
|
||
|
}
|
||
|
return creds, nil
|
||
|
}
|
||
|
|
||
|
// NewServerCreds returns a server-side transport credentials object that uses
|
||
|
// the S2Av2 to establish a secure connection with a client.
|
||
|
func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
|
||
|
// Create an AccessTokenManager instance to use to authenticate to S2Av2.
|
||
|
accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
|
||
|
creds := &s2av2TransportCreds{
|
||
|
info: &credentials.ProtocolInfo{
|
||
|
SecurityProtocol: s2aSecurityProtocol,
|
||
|
},
|
||
|
isClient: false,
|
||
|
s2av2Address: s2av2Address,
|
||
|
transportCreds: transportCreds,
|
||
|
localIdentities: localIdentities,
|
||
|
verificationMode: verificationMode,
|
||
|
getS2AStream: getS2AStream,
|
||
|
}
|
||
|
if err != nil {
|
||
|
creds.tokenManager = nil
|
||
|
} else {
|
||
|
creds.tokenManager = &accessTokenManager
|
||
|
}
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Info("Created server S2Av2 transport credentials.")
|
||
|
}
|
||
|
return creds, nil
|
||
|
}
|
||
|
|
||
|
// ClientHandshake performs a client-side mTLS handshake using the S2Av2.
|
||
|
func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||
|
if !c.isClient {
|
||
|
return nil, nil, errors.New("client handshake called using server transport credentials")
|
||
|
}
|
||
|
// Remove the port from serverAuthority.
|
||
|
serverName := removeServerNamePort(serverAuthority)
|
||
|
timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
|
||
|
defer cancel()
|
||
|
var s2AStream stream.S2AStream
|
||
|
var err error
|
||
|
retry.Run(timeoutCtx,
|
||
|
func() error {
|
||
|
s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to connect to S2Av2: %v", err)
|
||
|
if c.fallbackClientHandshake != nil {
|
||
|
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
|
||
|
}
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
defer s2AStream.CloseSend()
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Infof("Connected to S2Av2.")
|
||
|
}
|
||
|
var config *tls.Config
|
||
|
|
||
|
var tokenManager tokenmanager.AccessTokenManager
|
||
|
if c.tokenManager == nil {
|
||
|
tokenManager = nil
|
||
|
} else {
|
||
|
tokenManager = *c.tokenManager
|
||
|
}
|
||
|
|
||
|
sn := serverName
|
||
|
if c.serverName != "" {
|
||
|
sn = c.serverName
|
||
|
}
|
||
|
retry.Run(timeoutCtx,
|
||
|
func() error {
|
||
|
config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
|
||
|
if c.fallbackClientHandshake != nil {
|
||
|
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
|
||
|
}
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Infof("Got client TLS config from S2Av2.")
|
||
|
}
|
||
|
|
||
|
creds := credentials.NewTLS(config)
|
||
|
var conn net.Conn
|
||
|
var authInfo credentials.AuthInfo
|
||
|
retry.Run(timeoutCtx,
|
||
|
func() error {
|
||
|
conn, authInfo, err = creds.ClientHandshake(timeoutCtx, serverName, rawConn)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
|
||
|
if c.fallbackClientHandshake != nil {
|
||
|
return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
|
||
|
}
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
grpclog.Infof("Successfully done client handshake using S2Av2 to: %s", serverName)
|
||
|
|
||
|
return conn, authInfo, err
|
||
|
}
|
||
|
|
||
|
// ServerHandshake performs a server-side mTLS handshake using the S2Av2.
|
||
|
func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||
|
if c.isClient {
|
||
|
return nil, nil, errors.New("server handshake called using client transport credentials")
|
||
|
}
|
||
|
ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
|
||
|
defer cancel()
|
||
|
var s2AStream stream.S2AStream
|
||
|
var err error
|
||
|
retry.Run(ctx,
|
||
|
func() error {
|
||
|
s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to connect to S2Av2: %v", err)
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
defer s2AStream.CloseSend()
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Infof("Connected to S2Av2.")
|
||
|
}
|
||
|
|
||
|
var tokenManager tokenmanager.AccessTokenManager
|
||
|
if c.tokenManager == nil {
|
||
|
tokenManager = nil
|
||
|
} else {
|
||
|
tokenManager = *c.tokenManager
|
||
|
}
|
||
|
|
||
|
var config *tls.Config
|
||
|
retry.Run(ctx,
|
||
|
func() error {
|
||
|
config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Infof("Got server TLS config from S2Av2.")
|
||
|
}
|
||
|
|
||
|
creds := credentials.NewTLS(config)
|
||
|
var conn net.Conn
|
||
|
var authInfo credentials.AuthInfo
|
||
|
retry.Run(ctx,
|
||
|
func() error {
|
||
|
conn, authInfo, err = creds.ServerHandshake(rawConn)
|
||
|
return err
|
||
|
})
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
return conn, authInfo, err
|
||
|
}
|
||
|
|
||
|
// Info returns protocol info of s2av2TransportCreds.
|
||
|
func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
|
||
|
return *c.info
|
||
|
}
|
||
|
|
||
|
// Clone makes a deep copy of s2av2TransportCreds.
|
||
|
func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
|
||
|
info := *c.info
|
||
|
serverName := c.serverName
|
||
|
fallbackClientHandshake := c.fallbackClientHandshake
|
||
|
|
||
|
s2av2Address := c.s2av2Address
|
||
|
var tokenManager tokenmanager.AccessTokenManager
|
||
|
if c.tokenManager == nil {
|
||
|
tokenManager = nil
|
||
|
} else {
|
||
|
tokenManager = *c.tokenManager
|
||
|
}
|
||
|
verificationMode := c.verificationMode
|
||
|
var localIdentity *commonpbv1.Identity
|
||
|
if c.localIdentity != nil {
|
||
|
localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
|
||
|
}
|
||
|
var localIdentities []*commonpbv1.Identity
|
||
|
if c.localIdentities != nil {
|
||
|
localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
|
||
|
for i, localIdentity := range c.localIdentities {
|
||
|
localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
|
||
|
}
|
||
|
}
|
||
|
creds := &s2av2TransportCreds{
|
||
|
info: &info,
|
||
|
isClient: c.isClient,
|
||
|
serverName: serverName,
|
||
|
fallbackClientHandshake: fallbackClientHandshake,
|
||
|
s2av2Address: s2av2Address,
|
||
|
localIdentity: localIdentity,
|
||
|
localIdentities: localIdentities,
|
||
|
verificationMode: verificationMode,
|
||
|
}
|
||
|
if c.tokenManager == nil {
|
||
|
creds.tokenManager = nil
|
||
|
} else {
|
||
|
creds.tokenManager = &tokenManager
|
||
|
}
|
||
|
return creds
|
||
|
}
|
||
|
|
||
|
// NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
|
||
|
// a client. The tls.Config MUST only be used to establish a single TLS connection.
|
||
|
func NewClientTLSConfig(
|
||
|
ctx context.Context,
|
||
|
s2av2Address string,
|
||
|
transportCreds credentials.TransportCredentials,
|
||
|
tokenManager tokenmanager.AccessTokenManager,
|
||
|
verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
|
||
|
serverName string,
|
||
|
serverAuthorizationPolicy []byte) (*tls.Config, error) {
|
||
|
s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to connect to S2Av2: %v", err)
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
|
||
|
}
|
||
|
|
||
|
// OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
|
||
|
// info. The ServerName MUST be a hostname.
|
||
|
func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
|
||
|
serverName := removeServerNamePort(serverNameOverride)
|
||
|
c.info.ServerName = serverName
|
||
|
c.serverName = serverName
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// removeServerNamePort strips the trailing port from serverName. When
// serverName cannot be split into host and port, it is returned unchanged.
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
|
||
|
|
||
|
type s2AGrpcStream struct {
|
||
|
stream s2av2pb.S2AService_SetUpSessionClient
|
||
|
}
|
||
|
|
||
|
func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
|
||
|
return x.stream.Send(m)
|
||
|
}
|
||
|
|
||
|
func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
|
||
|
return x.stream.Recv()
|
||
|
}
|
||
|
|
||
|
func (x s2AGrpcStream) CloseSend() error {
|
||
|
return x.stream.CloseSend()
|
||
|
}
|
||
|
|
||
|
func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
|
||
|
if getS2AStream != nil {
|
||
|
return getS2AStream(ctx, s2av2Address)
|
||
|
}
|
||
|
// TODO(rmehta19): Consider whether to close the connection to S2Av2.
|
||
|
conn, err := service.Dial(ctx, s2av2Address, transportCreds)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
client := s2av2pb.NewS2AServiceClient(conn)
|
||
|
gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &s2AGrpcStream{
|
||
|
stream: gRPCStream,
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
|
||
|
func GetS2ATimeout() time.Duration {
|
||
|
timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
|
||
|
if err != nil {
|
||
|
return defaultS2ATimeout
|
||
|
}
|
||
|
return timeout
|
||
|
}
|