peridot/vendor/github.com/google/s2a-go/internal/record/record.go

758 lines
30 KiB
Go
Raw Normal View History

/*
*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
// Package record implements the TLS 1.3 record protocol used by the S2A
// transport credentials.
package record
import (
"encoding/binary"
"errors"
"fmt"
"math"
"net"
"sync"
commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
"github.com/google/s2a-go/internal/record/internal/halfconn"
"github.com/google/s2a-go/internal/tokenmanager"
"google.golang.org/grpc/grpclog"
)
type (
	// recordType is the TLS 1.3 `ContentType` that identifies what an
	// encrypted record carries; see
	// https://tools.ietf.org/html/rfc8446#section-5.1.
	recordType byte

	// keyUpdateRequest is the `KeyUpdateRequest` field of a KeyUpdate
	// handshake message; see
	// https://tools.ietf.org/html/rfc8446#section-4.6.3.
	keyUpdateRequest byte

	// alertDescription is the `AlertDescription` field of a TLS alert; see
	// https://tools.ietf.org/html/rfc8446#section-6.
	alertDescription byte

	// sessionTicketState records whether session tickets have not yet been
	// received, are currently being received, or are no longer being
	// received.
	sessionTicketState byte
)

// Content types of decrypted TLS 1.3 records.
const (
	alert           recordType = 21
	handshake       recordType = 22
	applicationData recordType = 23
)

// Values of the request_update field of a KeyUpdate message.
const (
	updateNotRequested keyUpdateRequest = iota // 0
	updateRequested                            // 1
)

// closeNotify is the close_notify alert description.
const closeNotify alertDescription = 0

// States of session ticket reception.
const (
	ticketsNotYetReceived sessionTicketState = iota // 0
	receivingTickets                                // 1
	notReceivingTickets                             // 2
)
const (
	// The TLS 1.3-specific constants below (tlsRecordMaxPlaintextSize,
	// tlsRecordHeaderSize, tlsRecordTypeSize) were taken from
	// https://tools.ietf.org/html/rfc8446#section-5.1.

	// tlsRecordMaxPlaintextSize is the maximum size in bytes of the plaintext
	// in a single TLS 1.3 record.
	tlsRecordMaxPlaintextSize = 16384 // 2^14
	// tlsRecordTypeSize is the size in bytes of the TLS 1.3 record type.
	tlsRecordTypeSize = 1
	// tlsTagSize is the size in bytes of the tag of the following three
	// ciphersuites: AES-128-GCM-SHA256, AES-256-GCM-SHA384,
	// CHACHA20-POLY1305-SHA256.
	tlsTagSize = 16
	// tlsRecordMaxPayloadSize is the maximum size in bytes of the payload in a
	// single TLS 1.3 record. This is the maximum size of the plaintext plus the
	// record type byte and 16 bytes of the tag.
	tlsRecordMaxPayloadSize = tlsRecordMaxPlaintextSize + tlsRecordTypeSize + tlsTagSize
	// tlsRecordHeaderTypeSize is the size in bytes of the TLS 1.3 record
	// header type.
	tlsRecordHeaderTypeSize = 1
	// tlsRecordHeaderLegacyRecordVersionSize is the size in bytes of the TLS
	// 1.3 record header legacy record version.
	tlsRecordHeaderLegacyRecordVersionSize = 2
	// tlsRecordHeaderPayloadLengthSize is the size in bytes of the TLS 1.3
	// record header payload length.
	tlsRecordHeaderPayloadLengthSize = 2
	// tlsRecordHeaderSize is the size in bytes of the TLS 1.3 record header:
	// [type (1)] + [legacy record version (2)] + [payload length (2)].
	tlsRecordHeaderSize = tlsRecordHeaderTypeSize + tlsRecordHeaderLegacyRecordVersionSize + tlsRecordHeaderPayloadLengthSize
	// tlsRecordMaxSize is the maximum size in bytes of a complete TLS 1.3
	// record on the wire, i.e. the record header followed by the largest
	// allowed payload (tlsRecordMaxPayloadSize).
	tlsRecordMaxSize = tlsRecordMaxPayloadSize + tlsRecordHeaderSize
	// tlsApplicationData is the application data type of the TLS 1.3 record
	// header. It is the outer content type used for every encrypted record;
	// the real content type is carried inside the encrypted payload.
	tlsApplicationData = 23
	// tlsLegacyRecordVersion is the value of each of the two legacy record
	// version bytes in the TLS record header, which together read 0x0303.
	tlsLegacyRecordVersion = 3
	// tlsAlertSize is the size in bytes of a TLS 1.3 alert: one level byte
	// followed by one description byte.
	tlsAlertSize = 2
)
// TLS 1.3 handshake-layer constants. A handshake message is framed as
// [msg type (1 byte)] + [body length (3 bytes)] + [body].
const (
	// tlsHandshakeNewSessionTicketType is the msg_type of a NewSessionTicket
	// handshake message.
	tlsHandshakeNewSessionTicketType = 4
	// tlsHandshakeKeyUpdateType is the msg_type of a KeyUpdate handshake
	// message.
	tlsHandshakeKeyUpdateType = 24

	// tlsHandshakeMsgTypeSize is the width in bytes of the msg_type field.
	tlsHandshakeMsgTypeSize = 1
	// tlsHandshakeLengthSize is the width in bytes of the length field.
	tlsHandshakeLengthSize = 3
	// tlsHandshakePrefixSize is the width in bytes of the framing that
	// precedes every handshake message body (msg_type plus length).
	tlsHandshakePrefixSize = tlsHandshakeMsgTypeSize + tlsHandshakeLengthSize

	// tlsHandshakeKeyUpdateMsgSize is the body size in bytes of a KeyUpdate
	// message.
	tlsHandshakeKeyUpdateMsgSize = 1

	// tlsMaxSessionTicketSize is the largest possible NewSessionTicket message
	// in TLS 1.3: the sum of the maximum sizes of every field of the
	// NewSessionTicket struct defined in
	// https://tools.ietf.org/html/rfc8446#section-4.6.1.
	tlsMaxSessionTicketSize = 131338
)
const (
	// outBufMaxRecords is the maximum number of records that can fit in the
	// outRecordsBuf buffer.
	outBufMaxRecords = 16
	// outBufMaxSize is the maximum size (in bytes) of the outRecordsBuf buffer.
	outBufMaxSize = outBufMaxRecords * tlsRecordMaxSize
	// maxAllowedTickets is the maximum number of session tickets that are
	// allowed. The number of tickets is limited so that the ticket queue
	// cannot grow indefinitely. S2A also keeps a limit on the number of
	// tickets that it caches.
	maxAllowedTickets = 5
)

// preConstructedKeyUpdateMsg holds a serialized KeyUpdate handshake message
// whose request_update field is update_not_requested. It is the message sent
// back when the peer requests a key update. It is built once at package
// initialization so the same message is not reconstructed on every key
// update.
var preConstructedKeyUpdateMsg = buildKeyUpdateRequest()
// conn represents a secured TLS 1.3 connection layered on top of an
// underlying net.Conn. It implements the net.Conn interface; any net.Conn
// method that conn does not define itself is promoted from the embedded
// connection.
type conn struct {
	net.Conn
	// inConn is the half connection responsible for decrypting incoming bytes.
	inConn *halfconn.S2AHalfConnection
	// outConn is the half connection responsible for encrypting outgoing bytes.
	outConn *halfconn.S2AHalfConnection
	// pendingApplicationData holds data that has been read from the connection
	// and decrypted, but has not yet been returned by Read. Records are
	// decrypted in place, so this slice may alias unusedBuf.
	pendingApplicationData []byte
	// unusedBuf holds data read from the network that has not yet been
	// decrypted. This data might not consist of a complete record. It may
	// consist of several records, the last of which could be incomplete.
	unusedBuf []byte
	// outRecordsBuf is a buffer used to store outgoing TLS records before
	// they are written to the network. It starts at tlsRecordMaxSize bytes
	// and is enlarged on demand, up to outBufMaxSize bytes.
	outRecordsBuf []byte
	// nextRecord is the not-yet-consumed suffix of unusedBuf: it starts at the
	// first byte that has not been returned as part of a complete record and
	// may hold zero or more complete records followed by an incomplete one.
	nextRecord []byte
	// overheadSize is the overhead size in bytes of each TLS 1.3 record, which
	// is computed as overheadSize = header size + record type byte + tag size.
	// Note that there is no padding by zeros in the overhead calculation.
	overheadSize int
	// readMutex guards against concurrent calls to Read. This is required since
	// Close may be called during a Read.
	readMutex sync.Mutex
	// writeMutex guards against concurrent calls to Write. This is required
	// since Close may be called during a Write, and also because a key update
	// message may be written during a Read.
	writeMutex sync.Mutex
	// handshakeBuf accumulates handshake bytes received across records until
	// complete handshake messages can be parsed out of it.
	handshakeBuf []byte
	// ticketState is the current processing state of the session tickets.
	ticketState sessionTicketState
	// sessionTickets holds the completed session tickets until they are sent to
	// the handshaker service for processing.
	sessionTickets [][]byte
	// ticketSender sends session tickets to the S2A handshaker service.
	ticketSender s2aTicketSender
	// callComplete is handed to the ticket sender together with the session
	// tickets. Close receives from it before closing the connection, so that a
	// pending call to S2A can complete first. NOTE(review): presumably the
	// ticket sender signals it when the S2A call finishes — confirm in the
	// ticket sender implementation.
	callComplete chan bool
}
// ConnParameters holds the parameters used for creating a new conn object via
// NewConn.
type ConnParameters struct {
	// NetConn is the TCP connection to the peer. This parameter is required.
	NetConn net.Conn
	// Ciphersuite is the TLS ciphersuite negotiated by the S2A handshaker
	// service. This parameter is required.
	Ciphersuite commonpb.Ciphersuite
	// TLSVersion is the TLS version number negotiated by the S2A handshaker
	// service. This parameter is required and must be TLS 1.3.
	TLSVersion commonpb.TLSVersion
	// InTrafficSecret is the traffic secret used to derive the session key for
	// the inbound direction. This parameter is required.
	InTrafficSecret []byte
	// OutTrafficSecret is the traffic secret used to derive the session key
	// for the outbound direction. This parameter is required.
	OutTrafficSecret []byte
	// UnusedBuf is the data read from the network that has not yet been
	// decrypted. This parameter is optional. If not provided, then no
	// application data was sent in the same flight of messages as the final
	// handshake message. NewConn copies it, so the caller keeps ownership.
	UnusedBuf []byte
	// InSequence is the sequence number of the next incoming TLS record.
	// This parameter is required.
	InSequence uint64
	// OutSequence is the sequence number of the next outgoing TLS record.
	// This parameter is required.
	OutSequence uint64
	// HSAddr stores the address of the S2A handshaker service. This parameter
	// is optional. If not provided, then TLS resumption is disabled.
	HSAddr string
	// ConnectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake.
	ConnectionID uint64
	// LocalIdentity is the local identity that was used by S2A during session
	// setup and included in the session result.
	LocalIdentity *commonpb.Identity
	// EnsureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	EnsureProcessSessionTickets *sync.WaitGroup
}
// NewConn creates a TLS 1.3 record protocol connection that wraps the TCP
// connection o.NetConn.
//
// The inbound and outbound half connections are derived from the negotiated
// ciphersuite, traffic secrets and sequence numbers in o. Any bytes in
// o.UnusedBuf are copied, so the caller keeps ownership of that slice.
// Session tickets received later are handed to a ticket sender configured
// with o.HSAddr, o.ConnectionID and o.LocalIdentity. A failure to create the
// access token manager is only logged and does not fail NewConn.
//
// It returns an error if o is nil, if the negotiated version is not TLS 1.3,
// or if either half connection cannot be created. Half-connection errors are
// wrapped with %w so callers can inspect the cause with errors.Is/errors.As.
func NewConn(o *ConnParameters) (net.Conn, error) {
	if o == nil {
		return nil, errors.New("conn options must not be nil")
	}
	if o.TLSVersion != commonpb.TLSVersion_TLS1_3 {
		return nil, errors.New("TLS version must be TLS 1.3")
	}
	inConn, err := halfconn.New(o.Ciphersuite, o.InTrafficSecret, o.InSequence)
	if err != nil {
		return nil, fmt.Errorf("failed to create inbound half connection: %w", err)
	}
	outConn, err := halfconn.New(o.Ciphersuite, o.OutTrafficSecret, o.OutSequence)
	if err != nil {
		return nil, fmt.Errorf("failed to create outbound half connection: %w", err)
	}
	// The tag size for the in/out connections should be the same.
	overheadSize := tlsRecordHeaderSize + tlsRecordTypeSize + inConn.TagSize()

	var unusedBuf []byte
	if o.UnusedBuf == nil {
		// We pre-allocate unusedBuf to be of size 2*tlsRecordMaxSize-1 during
		// initialization. We only read from the network into unusedBuf when
		// unusedBuf does not contain a complete record and the incomplete
		// record is at most tlsRecordMaxSize-1 bytes. And we read at most
		// tlsRecordMaxSize bytes of data from the network into unusedBuf at
		// one time. Therefore, 2*tlsRecordMaxSize-1 is large enough to buffer
		// data read from the network.
		unusedBuf = make([]byte, 0, 2*tlsRecordMaxSize-1)
	} else {
		unusedBuf = make([]byte, len(o.UnusedBuf))
		copy(unusedBuf, o.UnusedBuf)
	}

	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
	if err != nil {
		// Best effort: the connection is still usable without a token manager.
		grpclog.Infof("failed to create single token access token manager: %v", err)
	}

	s2aConn := &conn{
		Conn:          o.NetConn,
		inConn:        inConn,
		outConn:       outConn,
		unusedBuf:     unusedBuf,
		outRecordsBuf: make([]byte, tlsRecordMaxSize),
		nextRecord:    unusedBuf,
		overheadSize:  overheadSize,
		ticketState:   ticketsNotYetReceived,
		// Pre-allocate the buffer for one session ticket message and the max
		// plaintext size. This is the largest size that handshakeBuf will need
		// to hold. The largest incomplete handshake message is the
		// [handshake header size] + [max session ticket size] - 1.
		// Then, tlsRecordMaxPlaintextSize is the maximum size that will be
		// appended to the handshakeBuf before the handshake message is
		// completed. Therefore, the buffer size below should be large enough to
		// buffer any handshake messages.
		handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1),
		ticketSender: &ticketSender{
			hsAddr:                      o.HSAddr,
			connectionID:                o.ConnectionID,
			localIdentity:               o.LocalIdentity,
			tokenManager:                tokenManager,
			ensureProcessSessionTickets: o.EnsureProcessSessionTickets,
		},
		callComplete: make(chan bool),
	}
	return s2aConn, nil
}
// Read decrypts TLS 1.3 records from the underlying connection and copies the
// application data they carry into b.
//
// If a record's plaintext is longer than len(b), the surplus is kept in an
// internal buffer and handed out by later calls before any new record is
// read. Each call therefore returns data from at most one TLS record.
//
// A call that consumes a handshake record (a key update or session tickets)
// returns 0 and a nil error. A call that consumes an alert returns an error.
//
// For efficiency, callers should size b so that it is filled completely, e.g.
// by reading through io.ReadFull, which keeps calling Read until b is full.
// Callers should Close the connection after Read returns an error.
func (p *conn) Read(b []byte) (n int, err error) {
	p.readMutex.Lock()
	defer p.readMutex.Unlock()

	// Only pull a new record off the wire once everything decrypted by an
	// earlier call has been handed to the caller.
	if len(p.pendingApplicationData) == 0 {
		var record, header, payload []byte
		if record, err = p.readFullRecord(); err != nil {
			return 0, err
		}
		// A record is [header] + [payload], where
		//   [payload] = [ciphertext of application data]
		//             + [ciphertext of record type byte]
		//             + [(optionally) ciphertext of padding by zeros]
		//             + [tag]
		if header, payload, err = splitAndValidateHeader(record); err != nil {
			return 0, err
		}
		// Decrypt in place, authenticating the header as additional data.
		if p.pendingApplicationData, err = p.inConn.Decrypt(payload[:0], payload, header); err != nil {
			return 0, err
		}
		// Drop the zero padding and the inner content-type byte.
		var contentType recordType
		if contentType, err = p.stripPaddingAndType(); err != nil {
			return 0, err
		}
		if len(p.pendingApplicationData) > tlsRecordMaxPlaintextSize {
			return 0, errors.New("plaintext size larger than maximum")
		}

		switch contentType {
		case applicationData:
			// Application data must not interleave with a partially received
			// handshake message.
			if len(p.handshakeBuf) > 0 {
				return 0, errors.New("application data received while processing fragmented handshake messages")
			}
			// The first application data record after a run of session
			// tickets marks the end of that run; forward the batch to S2A.
			if p.ticketState == receivingTickets {
				p.ticketState = notReceivingTickets
				grpclog.Infof("Sending session tickets to S2A.")
				p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
			}
		case alert:
			return 0, p.handleAlertMessage()
		case handshake:
			return 0, p.handleHandshakeMessage()
		default:
			return 0, errors.New("unknown record type")
		}
	}

	// Hand out as much buffered application data as fits in b.
	n = copy(b, p.pendingApplicationData)
	p.pendingApplicationData = p.pendingApplicationData[n:]
	return n, nil
}
// Write encrypts b as one or more TLS 1.3 records of type application data
// and sends them to the peer. Each record carries at most
// tlsRecordMaxPlaintextSize bytes of plaintext. The returned count is the
// number of plaintext bytes (not wire bytes) successfully sent.
func (p *conn) Write(b []byte) (int, error) {
	p.writeMutex.Lock()
	defer p.writeMutex.Unlock()
	written, err := p.writeTLSRecord(b, tlsApplicationData)
	return written, err
}
// writeTLSRecord splits b into segments of at most tlsRecordMaxPlaintextSize
// bytes, builds a TLS 1.3 record with inner content type recordType from each
// segment, and sends the records to the peer.
//
// Records are staged in p.outRecordsBuf and written to p.Conn in chunks of at
// most outBufMaxRecords records. The buffer is enlarged if needed. An empty b
// still produces one record consisting of the header, the record type byte
// and the tag.
//
// It returns the number of plaintext bytes that were successfully sent to the
// peer. On a write error, only plaintext belonging to records that were
// completely written is counted. Callers in this file hold p.writeMutex while
// calling it, because p.outRecordsBuf is shared state.
func (p *conn) writeTLSRecord(b []byte, recordType byte) (n int, err error) {
	if len(b) == 0 {
		recordEndIndex, _, err := p.buildRecord(b, recordType, 0)
		if err != nil {
			return 0, err
		}
		// Only plaintext bytes are counted, so an empty record always reports
		// 0 even though p.Conn.Write writes the whole record.
		_, err = p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
		return 0, err
	}

	numRecords := int(math.Ceil(float64(len(b)) / float64(tlsRecordMaxPlaintextSize)))
	totalRecordsSize := len(b) + numRecords*p.overheadSize
	partialBSize := len(b)
	if totalRecordsSize > outBufMaxSize {
		// Too much for a single buffer: send outBufMaxRecords full records at
		// a time.
		totalRecordsSize = outBufMaxSize
		partialBSize = outBufMaxRecords * tlsRecordMaxPlaintextSize
	}
	if len(p.outRecordsBuf) < totalRecordsSize {
		p.outRecordsBuf = make([]byte, totalRecordsSize)
	}

	// fullRecordSize is the on-the-wire size of a record that carries
	// tlsRecordMaxPlaintextSize bytes of plaintext.
	fullRecordSize := tlsRecordMaxPlaintextSize + p.overheadSize
	for bStart := 0; bStart < len(b); bStart += partialBSize {
		bEnd := bStart + partialBSize
		if bEnd > len(b) {
			bEnd = len(b)
		}
		partialB := b[bStart:bEnd]
		recordEndIndex := 0
		for len(partialB) > 0 {
			recordEndIndex, partialB, err = p.buildRecord(partialB, recordType, recordEndIndex)
			if err != nil {
				// Return the amount of bytes written prior to the error.
				return bStart, err
			}
		}
		nn, err := p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
		if err != nil {
			// If the whole chunk reached the wire, every record in it
			// (including a shorter final one) was sent.
			if nn >= recordEndIndex {
				return bEnd, err
			}
			// Otherwise count only the completely written records. All
			// records in the chunk except possibly the last are full-sized.
			return bStart + (nn/fullRecordSize)*tlsRecordMaxPlaintextSize, err
		}
	}
	return len(b), nil
}
// buildRecord builds one TLS 1.3 record with inner content type recordType
// from the front of plaintext. It writes the record to p.outRecordsBuf
// starting at recordStartIndex.
//
// The record carries at most tlsRecordMaxPlaintextSize bytes of plaintext. It
// is laid out as
//
//	[header (5)] + encrypt([data] + [recordType]) + [tag]
//
// The header announces the outer type tlsApplicationData, legacy version
// 0x0303, and the ciphertext length.
//
// It returns the index in p.outRecordsBuf just past the record and the
// plaintext bytes that did not fit. On an encryption error it returns 0, the
// unmodified plaintext, and the error. The caller must ensure
// p.outRecordsBuf[recordStartIndex:] has room for the header, the data, the
// type byte and the tag.
func (p *conn) buildRecord(plaintext []byte, recordType byte, recordStartIndex int) (n int, remainingPlaintext []byte, err error) {
	// Take at most one record's worth of plaintext.
	dataLen := len(plaintext)
	if dataLen > tlsRecordMaxPlaintextSize {
		dataLen = tlsRecordMaxPlaintextSize
	}
	remainingPlaintext = plaintext[dataLen:]
	newRecordBuf := p.outRecordsBuf[recordStartIndex:]

	// Construct the inner payload right after the header: data + record type.
	copy(newRecordBuf[tlsRecordHeaderSize:], plaintext[:dataLen])
	newRecordBuf[tlsRecordHeaderSize+dataLen] = recordType
	payload := newRecordBuf[tlsRecordHeaderSize : tlsRecordHeaderSize+dataLen+1] // 1 is for the recordType.

	// Construct the header. The length field must equal the ciphertext
	// length, so use the outbound AEAD's real tag size rather than assuming
	// 16 bytes.
	newRecordBuf[0] = tlsApplicationData
	newRecordBuf[1] = tlsLegacyRecordVersion
	newRecordBuf[2] = tlsLegacyRecordVersion
	binary.BigEndian.PutUint16(newRecordBuf[3:], uint16(len(payload)+p.outConn.TagSize()))
	header := newRecordBuf[:tlsRecordHeaderSize]

	// Encrypt the payload using header as aad. The destination is a
	// zero-length slice at the payload position, so the ciphertext is
	// expected to be produced in place.
	encryptedPayload, err := p.outConn.Encrypt(newRecordBuf[tlsRecordHeaderSize:][:0], payload, header)
	if err != nil {
		return 0, plaintext, err
	}
	recordStartIndex += len(header) + len(encryptedPayload)
	return recordStartIndex, remainingPlaintext, nil
}
// Close closes the underlying connection.
//
// It acquires the read lock and then the write lock, so it waits for any
// in-progress Read or Write to return first. If a batch of session tickets
// has already been handed to the ticket sender (ticketState is
// notReceivingTickets), Close first blocks until it receives from
// p.callComplete. This gives the pending call to S2A a chance to complete.
//
// NOTE(review): whether a second Close after tickets were sent can block
// depends on how the ticket sender signals p.callComplete — confirm.
func (p *conn) Close() error {
	p.readMutex.Lock()
	defer p.readMutex.Unlock()
	p.writeMutex.Lock()
	defer p.writeMutex.Unlock()

	ticketsHandedOff := p.ticketState == notReceivingTickets
	if ticketsHandedOff {
		<-p.callComplete
		grpclog.Infof("Safe to close the connection because sending tickets to S2A is (already) complete.")
	}
	return p.Conn.Close()
}
// stripPaddingAndType removes the trailing zero padding and the inner
// content-type byte from p.pendingApplicationData, which is expected to have
// the layout
//
//	[application data] + [record type byte] + [trailing zeros]
//
// It returns the content type, taken from the last non-zero byte. If every
// byte after index 0 is zero, the byte at index 0 is used. It returns an error
// if p.pendingApplicationData is empty.
func (p *conn) stripPaddingAndType() (recordType, error) {
	data := p.pendingApplicationData
	if len(data) == 0 {
		return 0, errors.New("application data had length 0")
	}
	// Walk backwards over the zero padding until the type byte is reached.
	typeIdx := len(data) - 1
	for typeIdx > 0 && data[typeIdx] == 0 {
		typeIdx--
	}
	p.pendingApplicationData = data[:typeIdx]
	return recordType(data[typeIdx]), nil
}
// readFullRecord returns the next complete TLS record (header and payload),
// reading from p.Conn as many times as needed to obtain it.
//
// p.nextRecord is the not-yet-consumed suffix of p.unusedBuf. If it already
// holds a complete record, that record is returned without any network I/O.
// Otherwise the partial record is moved to the front of p.unusedBuf, and more
// bytes are read after it, at most tlsRecordMaxSize bytes per p.Conn.Read
// call. Whenever p.unusedBuf is completely full, it is reallocated with
// tlsRecordMaxSize bytes of extra capacity.
//
// The returned slice aliases p.unusedBuf; nothing is copied. It returns an
// error if a header announces a payload longer than tlsRecordMaxPayloadSize,
// or if p.Conn.Read fails before a complete record has been buffered.
func (p *conn) readFullRecord() (fullRecord []byte, err error) {
	fullRecord, p.nextRecord, err = parseReadBuffer(p.nextRecord, tlsRecordMaxPayloadSize)
	if err != nil {
		return nil, err
	}
	// Check whether the next record to be decrypted has been completely
	// received. If not, always copy the incomplete record to the beginning of
	// the unusedBuf buffer and reset nextRecord to it. copy handles the
	// overlap, since nextRecord is a suffix of unusedBuf.
	if len(fullRecord) == 0 {
		copy(p.unusedBuf, p.nextRecord)
		p.unusedBuf = p.unusedBuf[:len(p.nextRecord)]
		p.nextRecord = p.unusedBuf
	}
	// Keep reading from the wire until we have a complete record.
	for len(fullRecord) == 0 {
		if len(p.unusedBuf) == cap(p.unusedBuf) {
			tmp := make([]byte, len(p.unusedBuf), cap(p.unusedBuf)+tlsRecordMaxSize)
			copy(tmp, p.unusedBuf)
			p.unusedBuf = tmp
		}
		n, readErr := p.Conn.Read(p.unusedBuf[len(p.unusedBuf):min(cap(p.unusedBuf), len(p.unusedBuf)+tlsRecordMaxSize)])
		// Per the io.Reader contract, bytes returned together with an error
		// are valid data, so buffer them before looking at readErr. This
		// keeps a final record delivered alongside io.EOF from being lost.
		p.unusedBuf = p.unusedBuf[:len(p.unusedBuf)+n]
		fullRecord, p.nextRecord, err = parseReadBuffer(p.unusedBuf, tlsRecordMaxPayloadSize)
		if err != nil {
			return nil, err
		}
		// A completed record is returned even if the read also failed. A
		// sticky error such as io.EOF will resurface on the next network
		// read. NOTE(review): a transient, non-sticky error would be dropped
		// in that case.
		if len(fullRecord) == 0 && readErr != nil {
			return nil, readErr
		}
	}
	return fullRecord, nil
}
// parseReadBuffer parses the provided buffer and returns a full record (header
// plus payload) and any remaining bytes in that buffer.
//
// If the record is incomplete, including when fewer than tlsRecordHeaderSize
// bytes are buffered, it returns nil for the record and b itself as the
// remaining bytes. If the payload length in the header is greater than maxLen,
// it returns an error and both slices are nil. The returned slices alias b;
// this function does not allocate or copy any buffers.
func parseReadBuffer(b []byte, maxLen uint16) (fullRecord, remaining []byte, err error) {
	// If the header is not complete, return the provided buffer as remaining
	// buffer.
	if len(b) < tlsRecordHeaderSize {
		return nil, b, nil
	}
	msgLenField := b[tlsRecordHeaderTypeSize+tlsRecordHeaderLegacyRecordVersionSize : tlsRecordHeaderSize]
	length := binary.BigEndian.Uint16(msgLenField)
	if length > maxLen {
		return nil, nil, fmt.Errorf("record length larger than the limit %d", maxLen)
	}
	// Compute the record size as an int. tlsRecordHeaderSize+length evaluated
	// as uint16 wraps around for lengths above 65530 and would slice off a
	// truncated record.
	recordLen := tlsRecordHeaderSize + int(length)
	if len(b) < recordLen {
		// Record is not complete yet.
		return nil, b, nil
	}
	return b[:recordLen], b[recordLen:], nil
}
// splitAndValidateHeader splits a TLS 1.3 record into its header and its
// payload. The header must carry the outer type tlsApplicationData and the
// legacy record version 0x0303; otherwise an error is returned. The returned
// slices alias record; nothing is allocated or copied.
func splitAndValidateHeader(record []byte) (header, payload []byte, err error) {
	if len(record) < tlsRecordHeaderSize {
		return nil, nil, errors.New("record was smaller than the header size")
	}
	hdr, body := record[:tlsRecordHeaderSize], record[tlsRecordHeaderSize:]
	switch {
	case hdr[0] != tlsApplicationData:
		return nil, nil, errors.New("incorrect type in the header")
	case hdr[1] != tlsLegacyRecordVersion || hdr[2] != tlsLegacyRecordVersion:
		return nil, nil, errors.New("incorrect legacy record version in the header")
	}
	return hdr, body, nil
}
// handleAlertMessage consumes the alert held in p.pendingApplicationData and
// returns an error describing it. A close_notify alert and an unrecognized
// alert both end the Read with an error. An alert that is not exactly
// tlsAlertSize bytes long is rejected, and in that case
// p.pendingApplicationData is left untouched.
func (p *conn) handleAlertMessage() error {
	if len(p.pendingApplicationData) != tlsAlertSize {
		return errors.New("invalid alert message size")
	}
	// Byte 1 of an alert is its description (RFC 8446, section 6).
	description := alertDescription(p.pendingApplicationData[1])
	// Clear the body of the alert message.
	p.pendingApplicationData = p.pendingApplicationData[:0]
	if description != closeNotify {
		// TODO(matthewstevenson88): Add support for more alert types.
		return fmt.Errorf("received an unrecognized alert type: %v", byte(description))
	}
	return errors.New("received a close notify alert")
}
// parseHandshakeMsg tries to parse one complete handshake message from the
// front of p.handshakeBuf. A handshake message has the layout
//
//	[msg type (1 byte)] + [msg length (3 bytes, big endian)] + [msg body]
//
// On success it advances p.handshakeBuf past the message and returns:
//   - the message type
//   - the body length
//   - the body (msg)
//   - the whole message including its 4-byte prefix (rawMsg)
//   - ok == true
//
// If the buffer does not yet hold a complete message, because the prefix or
// the body is split across records, it returns ok == false and leaves
// p.handshakeBuf untouched.
//
// msg and rawMsg alias p.handshakeBuf. Their capacity is clipped to the end of
// the message, so appending to them cannot overwrite data buffered after it.
func (p *conn) parseHandshakeMsg() (msgType byte, msgLen uint32, msg []byte, rawMsg []byte, ok bool) {
	// Handle the case where the 4 byte handshake header is fragmented.
	if len(p.handshakeBuf) < tlsHandshakePrefixSize {
		return 0, 0, nil, nil, false
	}
	msgType = p.handshakeBuf[0]
	msgLen = bigEndianInt24(p.handshakeBuf[tlsHandshakeMsgTypeSize : tlsHandshakeMsgTypeSize+tlsHandshakeLengthSize])
	// Handle the case where the body is fragmented.
	if msgLen > uint32(len(p.handshakeBuf)-tlsHandshakePrefixSize) {
		return 0, 0, nil, nil, false
	}
	end := tlsHandshakePrefixSize + msgLen
	msg = p.handshakeBuf[tlsHandshakePrefixSize:end:end]
	rawMsg = p.handshakeBuf[:end:end]
	p.handshakeBuf = p.handshakeBuf[end:]
	return msgType, msgLen, msg, rawMsg, true
}
// handleHandshakeMessage appends the handshake bytes in
// p.pendingApplicationData to p.handshakeBuf. It then processes every complete
// handshake message now at the front of that buffer. Bytes of a trailing
// incomplete message stay buffered until more handshake records arrive.
//
// Supported messages:
//   - KeyUpdate: the body must be tlsHandshakeKeyUpdateMsgSize bytes, and the
//     message must be the last one buffered. It is then handled by
//     handleKeyUpdateMsg.
//   - NewSessionTicket: the raw message is queued in p.sessionTickets. Once
//     maxAllowedTickets are queued, the batch is handed to the ticket sender.
//     Tickets arriving after that are ignored.
//
// Any other handshake message type yields an error.
//
// NOTE(review): p.handshakeBuf grows without a cap while a message is
// incomplete. A 3-byte length field permits bodies up to 2^24-1 bytes, which
// is larger than the size pre-allocated in NewConn — consider rejecting
// oversized lengths early.
func (p *conn) handleHandshakeMessage() error {
	// Copy the pending application data to the handshake buffer. At this point,
	// we are guaranteed that the pending application data contains only parts
	// of a handshake message.
	p.handshakeBuf = append(p.handshakeBuf, p.pendingApplicationData...)
	p.pendingApplicationData = p.pendingApplicationData[:0]
	// Several handshake messages may be coalesced into a single record.
	// Continue reading them until the handshake buffer is empty.
	for len(p.handshakeBuf) > 0 {
		handshakeMsgType, msgLen, msg, rawMsg, ok := p.parseHandshakeMsg()
		if !ok {
			// The handshake could not be fully parsed, so read in another
			// record and try again later.
			break
		}
		switch handshakeMsgType {
		case tlsHandshakeKeyUpdateType:
			if msgLen != tlsHandshakeKeyUpdateMsgSize {
				return errors.New("invalid handshake key update message length")
			}
			// Nothing may follow a KeyUpdate in the buffered handshake data,
			// since subsequent records use the updated inbound key.
			if len(p.handshakeBuf) != 0 {
				return errors.New("key update message must be the last message of a handshake record")
			}
			if err := p.handleKeyUpdateMsg(msg); err != nil {
				return err
			}
		case tlsHandshakeNewSessionTicketType:
			// Ignore tickets that are received after a batch of tickets has
			// been sent to S2A.
			if p.ticketState == notReceivingTickets {
				continue
			}
			if p.ticketState == ticketsNotYetReceived {
				p.ticketState = receivingTickets
			}
			// rawMsg (prefix included) is queued; it aliases handshakeBuf's
			// backing array rather than being copied.
			p.sessionTickets = append(p.sessionTickets, rawMsg)
			// Once the cap is reached, hand the batch off immediately instead
			// of waiting for the next application data record.
			if len(p.sessionTickets) == maxAllowedTickets {
				p.ticketState = notReceivingTickets
				grpclog.Infof("Sending session tickets to S2A.")
				p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
			}
		default:
			return errors.New("unknown handshake message type")
		}
	}
	return nil
}
func buildKeyUpdateRequest() []byte {
b := make([]byte, tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize)
b[0] = tlsHandshakeKeyUpdateType
b[1] = 0
b[2] = 0
b[3] = tlsHandshakeKeyUpdateMsgSize
b[4] = byte(updateNotRequested)
return b
}
// handleKeyUpdateMsg processes the body of a KeyUpdate handshake message
// (RFC 8446, section 4.6.3). msg must hold at least one byte; its first byte
// is the request_update field, which must be updateNotRequested or
// updateRequested.
//
// The inbound traffic key is always advanced. If the peer set updateRequested,
// a KeyUpdate with updateNotRequested is first written under the current
// outbound key, and only then is the outbound key advanced. p.writeMutex is
// held while responding, so the reply cannot interleave with a concurrent
// Write.
func (p *conn) handleKeyUpdateMsg(msg []byte) error {
	req := keyUpdateRequest(msg[0])
	switch req {
	case updateNotRequested, updateRequested:
	default:
		return errors.New("invalid handshake key update message")
	}
	if err := p.inConn.UpdateKey(); err != nil {
		return err
	}
	if req != updateRequested {
		return nil
	}

	// The peer asked us to update as well: answer, then rotate our key.
	p.writeMutex.Lock()
	defer p.writeMutex.Unlock()
	written, err := p.writeTLSRecord(preConstructedKeyUpdateMsg, byte(handshake))
	if err != nil {
		return err
	}
	if written != tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize {
		return errors.New("key update request message wrote less bytes than expected")
	}
	return p.outConn.UpdateKey()
}
// bigEndianInt24 decodes the first 3 bytes of b as a big-endian unsigned
// 24-bit integer and returns it as a uint32. TLS 1.3 encodes handshake
// message lengths as 3-byte integers, and encoding/binary has no helper for
// that width. Bytes after index 2 are ignored. It panics if len(b) < 3.
func bigEndianInt24(b []byte) uint32 {
	_ = b[2] // Bounds check hint to the compiler; see golang.org/issue/14808.
	return uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])
}
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}