mirror of
https://github.com/rocky-linux/peridot.git
synced 2024-10-19 15:55:08 +00:00
405 lines
15 KiB
Go
405 lines
15 KiB
Go
|
/*
 *
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
|
||
|
|
||
|
// Package tlsconfigstore offloads operations to S2Av2.
|
||
|
package tlsconfigstore
|
||
|
|
||
|
import (
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
|
||
|
"github.com/google/s2a-go/internal/tokenmanager"
|
||
|
"github.com/google/s2a-go/internal/v2/certverifier"
|
||
|
"github.com/google/s2a-go/internal/v2/remotesigner"
|
||
|
"github.com/google/s2a-go/stream"
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
|
||
|
commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
|
||
|
commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
|
||
|
s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
|
||
|
)
|
||
|
|
||
|
// h2 is the protocol identifier for HTTP/2; it is the only entry placed in
// NextProtos by the TLS configs built in this file.
const h2 = "h2"
|
||
|
|
||
|
// GetTLSConfigurationForClient returns a tls.Config instance for use by a client application.
|
||
|
func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentity *commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, serverAuthorizationPolicy []byte) (*tls.Config, error) {
|
||
|
authMechanisms := getAuthMechanisms(tokenManager, []*commonpbv1.Identity{localIdentity})
|
||
|
|
||
|
if grpclog.V(1) {
|
||
|
grpclog.Infof("Sending request to S2Av2 for client TLS config.")
|
||
|
}
|
||
|
// Send request to S2Av2 for config.
|
||
|
if err := s2AStream.Send(&s2av2pb.SessionReq{
|
||
|
LocalIdentity: localIdentity,
|
||
|
AuthenticationMechanisms: authMechanisms,
|
||
|
ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
|
||
|
GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
|
||
|
ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT,
|
||
|
},
|
||
|
},
|
||
|
}); err != nil {
|
||
|
grpclog.Infof("Failed to send request to S2Av2 for client TLS config")
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Get the response containing config from S2Av2.
|
||
|
resp, err := s2AStream.Recv()
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Failed to receive client TLS config response from S2Av2.")
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// TODO(rmehta19): Add unit test for this if statement.
|
||
|
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
|
||
|
return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
|
||
|
}
|
||
|
|
||
|
// Extract TLS configiguration from SessionResp.
|
||
|
tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()
|
||
|
|
||
|
var cert tls.Certificate
|
||
|
for i, v := range tlsConfig.CertificateChain {
|
||
|
// Populate Certificates field.
|
||
|
block, _ := pem.Decode([]byte(v))
|
||
|
if block == nil {
|
||
|
return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
|
||
|
}
|
||
|
x509Cert, err := x509.ParseCertificate(block.Bytes)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
cert.Certificate = append(cert.Certificate, x509Cert.Raw)
|
||
|
if i == 0 {
|
||
|
cert.Leaf = x509Cert
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(tlsConfig.CertificateChain) > 0 {
|
||
|
cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
|
||
|
if cert.PrivateKey == nil {
|
||
|
return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
minVersion, maxVersion, err := getTLSMinMaxVersionsClient(tlsConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// Create mTLS credentials for client.
|
||
|
config := &tls.Config{
|
||
|
VerifyPeerCertificate: certverifier.VerifyServerCertificateChain(serverHostname, verificationMode, s2AStream, serverAuthorizationPolicy),
|
||
|
ServerName: serverHostname,
|
||
|
InsecureSkipVerify: true, // NOLINT
|
||
|
ClientSessionCache: nil,
|
||
|
SessionTicketsDisabled: true,
|
||
|
MinVersion: minVersion,
|
||
|
MaxVersion: maxVersion,
|
||
|
NextProtos: []string{h2},
|
||
|
}
|
||
|
if len(tlsConfig.CertificateChain) > 0 {
|
||
|
config.Certificates = []tls.Certificate{cert}
|
||
|
}
|
||
|
return config, nil
|
||
|
}
|
||
|
|
||
|
// GetTLSConfigurationForServer returns a tls.Config instance for use by a server application.
|
||
|
func GetTLSConfigurationForServer(s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode) (*tls.Config, error) {
|
||
|
return &tls.Config{
|
||
|
GetConfigForClient: ClientConfig(tokenManager, localIdentities, verificationMode, s2AStream),
|
||
|
}, nil
|
||
|
}
|
||
|
|
||
|
// ClientConfig builds a TLS config for a server to establish a secure
|
||
|
// connection with a client, based on SNI communicated during ClientHello.
|
||
|
// Ensures that server presents the correct certificate to establish a TLS
|
||
|
// connection.
|
||
|
func ClientConfig(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
|
return func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
|
||
|
tlsConfig, err := getServerConfigFromS2Av2(tokenManager, localIdentities, chi.ServerName, s2AStream)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
var cert tls.Certificate
|
||
|
for i, v := range tlsConfig.CertificateChain {
|
||
|
// Populate Certificates field.
|
||
|
block, _ := pem.Decode([]byte(v))
|
||
|
if block == nil {
|
||
|
return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
|
||
|
}
|
||
|
x509Cert, err := x509.ParseCertificate(block.Bytes)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
cert.Certificate = append(cert.Certificate, x509Cert.Raw)
|
||
|
if i == 0 {
|
||
|
cert.Leaf = x509Cert
|
||
|
}
|
||
|
}
|
||
|
|
||
|
cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
|
||
|
if cert.PrivateKey == nil {
|
||
|
return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
|
||
|
}
|
||
|
|
||
|
minVersion, maxVersion, err := getTLSMinMaxVersionsServer(tlsConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
clientAuth := getTLSClientAuthType(tlsConfig)
|
||
|
|
||
|
var cipherSuites []uint16
|
||
|
cipherSuites = getCipherSuites(tlsConfig.Ciphersuites)
|
||
|
|
||
|
// Create mTLS credentials for server.
|
||
|
return &tls.Config{
|
||
|
Certificates: []tls.Certificate{cert},
|
||
|
VerifyPeerCertificate: certverifier.VerifyClientCertificateChain(verificationMode, s2AStream),
|
||
|
ClientAuth: clientAuth,
|
||
|
CipherSuites: cipherSuites,
|
||
|
SessionTicketsDisabled: true,
|
||
|
MinVersion: minVersion,
|
||
|
MaxVersion: maxVersion,
|
||
|
NextProtos: []string{h2},
|
||
|
}, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func getCipherSuites(tlsConfigCipherSuites []commonpb.Ciphersuite) []uint16 {
|
||
|
var tlsGoCipherSuites []uint16
|
||
|
for _, v := range tlsConfigCipherSuites {
|
||
|
s := getTLSCipherSuite(v)
|
||
|
if s != 0xffff {
|
||
|
tlsGoCipherSuites = append(tlsGoCipherSuites, s)
|
||
|
}
|
||
|
}
|
||
|
return tlsGoCipherSuites
|
||
|
}
|
||
|
|
||
|
func getTLSCipherSuite(tlsCipherSuite commonpb.Ciphersuite) uint16 {
|
||
|
switch tlsCipherSuite {
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
|
||
|
return tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
|
||
|
return tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
|
||
|
return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
|
||
|
return tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
|
||
|
return tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
|
||
|
case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
|
||
|
return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
|
||
|
default:
|
||
|
return 0xffff
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func getServerConfigFromS2Av2(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity, sni string, s2AStream stream.S2AStream) (*s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration, error) {
|
||
|
authMechanisms := getAuthMechanisms(tokenManager, localIdentities)
|
||
|
var locID *commonpbv1.Identity
|
||
|
if localIdentities != nil {
|
||
|
locID = localIdentities[0]
|
||
|
}
|
||
|
|
||
|
if err := s2AStream.Send(&s2av2pb.SessionReq{
|
||
|
LocalIdentity: locID,
|
||
|
AuthenticationMechanisms: authMechanisms,
|
||
|
ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
|
||
|
GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
|
||
|
ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_SERVER,
|
||
|
Sni: sni,
|
||
|
},
|
||
|
},
|
||
|
}); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
resp, err := s2AStream.Recv()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// TODO(rmehta19): Add unit test for this if statement.
|
||
|
if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
|
||
|
return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
|
||
|
}
|
||
|
|
||
|
return resp.GetGetTlsConfigurationResp().GetServerTlsConfiguration(), nil
|
||
|
}
|
||
|
|
||
|
func getTLSClientAuthType(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) tls.ClientAuthType {
|
||
|
var clientAuth tls.ClientAuthType
|
||
|
switch x := tlsConfig.RequestClientCertificate; x {
|
||
|
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_DONT_REQUEST_CLIENT_CERTIFICATE:
|
||
|
clientAuth = tls.NoClientCert
|
||
|
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
|
||
|
clientAuth = tls.RequestClientCert
|
||
|
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
|
||
|
// This case actually maps to tls.VerifyClientCertIfGiven. However this
|
||
|
// mapping triggers normal verification, followed by custom verification,
|
||
|
// specified in VerifyPeerCertificate. To bypass normal verification, and
|
||
|
// only do custom verification we set clientAuth to RequireAnyClientCert or
|
||
|
// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
|
||
|
// discussion.
|
||
|
clientAuth = tls.RequireAnyClientCert
|
||
|
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
|
||
|
clientAuth = tls.RequireAnyClientCert
|
||
|
case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
|
||
|
// This case actually maps to tls.RequireAndVerifyClientCert. However this
|
||
|
// mapping triggers normal verification, followed by custom verification,
|
||
|
// specified in VerifyPeerCertificate. To bypass normal verification, and
|
||
|
// only do custom verification we set clientAuth to RequireAnyClientCert or
|
||
|
// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
|
||
|
// discussion.
|
||
|
clientAuth = tls.RequireAnyClientCert
|
||
|
default:
|
||
|
clientAuth = tls.RequireAnyClientCert
|
||
|
}
|
||
|
return clientAuth
|
||
|
}
|
||
|
|
||
|
func getAuthMechanisms(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpbv1.Identity) []*s2av2pb.AuthenticationMechanism {
|
||
|
if tokenManager == nil {
|
||
|
return nil
|
||
|
}
|
||
|
if len(localIdentities) == 0 {
|
||
|
token, err := tokenManager.DefaultToken()
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Unable to get token for empty local identity: %v", err)
|
||
|
return nil
|
||
|
}
|
||
|
return []*s2av2pb.AuthenticationMechanism{
|
||
|
{
|
||
|
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
|
||
|
Token: token,
|
||
|
},
|
||
|
},
|
||
|
}
|
||
|
}
|
||
|
var authMechanisms []*s2av2pb.AuthenticationMechanism
|
||
|
for _, localIdentity := range localIdentities {
|
||
|
if localIdentity == nil {
|
||
|
token, err := tokenManager.DefaultToken()
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Unable to get default token for local identity %v: %v", localIdentity, err)
|
||
|
continue
|
||
|
}
|
||
|
authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
|
||
|
Identity: localIdentity,
|
||
|
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
|
||
|
Token: token,
|
||
|
},
|
||
|
})
|
||
|
} else {
|
||
|
token, err := tokenManager.Token(localIdentity)
|
||
|
if err != nil {
|
||
|
grpclog.Infof("Unable to get token for local identity %v: %v", localIdentity, err)
|
||
|
continue
|
||
|
}
|
||
|
authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
|
||
|
Identity: localIdentity,
|
||
|
MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
|
||
|
Token: token,
|
||
|
},
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
return authMechanisms
|
||
|
}
|
||
|
|
||
|
// TODO(rmehta19): refactor switch statements into a helper function.
|
||
|
func getTLSMinMaxVersionsClient(tlsConfig *s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration) (uint16, uint16, error) {
|
||
|
// Map S2Av2 TLSVersion to consts defined in tls package.
|
||
|
var minVersion uint16
|
||
|
var maxVersion uint16
|
||
|
switch x := tlsConfig.MinTlsVersion; x {
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_0:
|
||
|
minVersion = tls.VersionTLS10
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_1:
|
||
|
minVersion = tls.VersionTLS11
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_2:
|
||
|
minVersion = tls.VersionTLS12
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_3:
|
||
|
minVersion = tls.VersionTLS13
|
||
|
default:
|
||
|
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
|
||
|
}
|
||
|
|
||
|
switch x := tlsConfig.MaxTlsVersion; x {
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_0:
|
||
|
maxVersion = tls.VersionTLS10
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_1:
|
||
|
maxVersion = tls.VersionTLS11
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_2:
|
||
|
maxVersion = tls.VersionTLS12
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_3:
|
||
|
maxVersion = tls.VersionTLS13
|
||
|
default:
|
||
|
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
|
||
|
}
|
||
|
if minVersion > maxVersion {
|
||
|
return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
|
||
|
}
|
||
|
return minVersion, maxVersion, nil
|
||
|
}
|
||
|
|
||
|
func getTLSMinMaxVersionsServer(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) (uint16, uint16, error) {
|
||
|
// Map S2Av2 TLSVersion to consts defined in tls package.
|
||
|
var minVersion uint16
|
||
|
var maxVersion uint16
|
||
|
switch x := tlsConfig.MinTlsVersion; x {
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_0:
|
||
|
minVersion = tls.VersionTLS10
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_1:
|
||
|
minVersion = tls.VersionTLS11
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_2:
|
||
|
minVersion = tls.VersionTLS12
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_3:
|
||
|
minVersion = tls.VersionTLS13
|
||
|
default:
|
||
|
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
|
||
|
}
|
||
|
|
||
|
switch x := tlsConfig.MaxTlsVersion; x {
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_0:
|
||
|
maxVersion = tls.VersionTLS10
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_1:
|
||
|
maxVersion = tls.VersionTLS11
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_2:
|
||
|
maxVersion = tls.VersionTLS12
|
||
|
case commonpb.TLSVersion_TLS_VERSION_1_3:
|
||
|
maxVersion = tls.VersionTLS13
|
||
|
default:
|
||
|
return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
|
||
|
}
|
||
|
if minVersion > maxVersion {
|
||
|
return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
|
||
|
}
|
||
|
return minVersion, maxVersion, nil
|
||
|
}
|