TUN-4597: Add a QUIC server skeleton

- Added a QUIC server to accept streams
- Unit test for this server also tests ALPN
- Temporary echo capability for HTTP ConnectionType
This commit is contained in:
Sudarsan Reddy
2021-08-03 10:04:02 +01:00
parent fd4000184c
commit ed024d0741
768 changed files with 84848 additions and 15639 deletions

View File

@@ -0,0 +1,155 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic key", suite.KeyLen)
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic iv", suite.IVLen())
return suite.AEAD(key, iv)
}
type longHeaderSealer struct {
aead cipher.AEAD
headerProtector headerProtector
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
return &longHeaderSealer{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
}
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
}
func (s *longHeaderSealer) Overhead() int {
return s.aead.Overhead()
}
type longHeaderOpener struct {
aead cipher.AEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
return &longHeaderOpener{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err == nil {
o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
}
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}
type handshakeSealer struct {
LongHeaderSealer
dropInitialKeys func()
dropped bool
}
func newHandshakeSealer(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderSealer {
sealer := newLongHeaderSealer(aead, headerProtector)
// The client drops Initial keys when sending the first Handshake packet.
if perspective == protocol.PerspectiveServer {
return sealer
}
return &handshakeSealer{
LongHeaderSealer: sealer,
dropInitialKeys: dropInitialKeys,
}
}
func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
data := s.LongHeaderSealer.Seal(dst, src, pn, ad)
if !s.dropped {
s.dropInitialKeys()
s.dropped = true
}
return data
}
type handshakeOpener struct {
LongHeaderOpener
dropInitialKeys func()
dropped bool
}
func newHandshakeOpener(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderOpener {
opener := newLongHeaderOpener(aead, headerProtector)
// The server drops Initial keys when first successfully processing a Handshake packet.
if perspective == protocol.PerspectiveClient {
return opener
}
return &handshakeOpener{
LongHeaderOpener: opener,
dropInitialKeys: dropInitialKeys,
}
}
func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad)
if err == nil && !o.dropped {
o.dropInitialKeys()
o.dropped = true
}
return dec, err
}

View File

@@ -0,0 +1,800 @@
package handshake
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
version protocol.VersionNumber
}
var _ ConnWithVersion = &conn{}
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
return &conn{
localAddr: local,
remoteAddr: remote,
version: version,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
type cryptoSetup struct {
tlsConf *tls.Config
extraConf *qtls.ExtraConfig
conn *qtls.Conn
version protocol.VersionNumber
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
paramsChan <-chan []byte
runner handshakeRunner
alertChan chan uint8
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// is closed when Close() is called
closeChan chan struct{}
zeroRTTParameters *wire.TransportParameters
clientHelloWritten bool
clientHelloWrittenChan chan *wire.TransportParameters
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
logger utils.Logger
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialStream io.Writer
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeStream io.Writer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var (
_ qtls.RecordLayer = &cryptoSetup{}
_ CryptoSetup = &cryptoSetup{}
)
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
version: version,
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = 0xffffffff
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
AlternativeRecordLayer: cs,
EnforceNextProtoSelection: true,
MaxEarlyData: maxEarlyData,
Accept0RTT: cs.accept0RTT,
Rejected0RTT: cs.rejected0RTT,
Enable0RTT: enable0RTT,
GetAppDataForSessionState: cs.marshalDataForSessionState,
SetAppDataFromSessionState: cs.handleDataFromSessionState,
}
return cs, cs.clientHelloWrittenChan
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
h.initialSealer = initialSealer
h.initialOpener = initialOpener
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
}
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) RunHandshake() {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.Handshake(); err != nil {
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-handshakeComplete: // return when the handshake is done
h.mutex.Lock()
h.handshakeCompleteTime = time.Now()
h.mutex.Unlock()
h.runner.OnHandshakeComplete()
case <-h.closeChan:
// wait until the Handshake() go routine has returned
<-h.handshakeDone
case alert := <-h.alertChan:
handshakeErr := <-handshakeErrChan
h.onError(alert, handshakeErr.Error())
}
}
func (h *cryptoSetup) onError(alert uint8, message string) {
h.runner.OnError(qerr.NewCryptoError(alert, message))
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
// It must only be called once.
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
}
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.onError(alertUnexpectedMessage, err.Error())
return false
}
h.messageChan <- data
if encLevel == protocol.Encryption1RTT {
h.handlePostHandshakeMessage()
return false
}
readLoop:
for {
select {
case data := <-h.paramsChan:
if data == nil {
h.onError(0x6d, "missing quic_transport_parameters extension")
} else {
h.handleTransportParameters(data)
}
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
case <-h.closeChan:
break readLoop
}
}
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello,
typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
case typeNewSessionTicket:
expected = protocol.Encryption1RTT
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
}
return nil
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
}
h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte {
buf := &bytes.Buffer{}
quicvarint.Write(buf, clientSessionStateRevision)
quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf)
return buf.Bytes()
}
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
tp, err := h.handleDataFromSessionStateImpl(data)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
h.zeroRTTParameters = tp
}
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err
}
return &tp, nil
}
// only valid for the server
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.extraConf.MaxEarlyData > 0 {
appData = (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
}
return h.conn.GetSessionTicket(appData)
}
// accept0RTT is called for the server when receiving the client's session ticket.
// It decides whether to accept 0-RTT.
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData); err != nil {
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if valid {
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
} else {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
}
return valid
}
// rejected0RTT is called for the client when the server rejects 0-RTT.
func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys {
h.runner.DropKeys(protocol.Encryption0RTT)
}
}
func (h *cryptoSetup) handlePostHandshakeMessage() {
// make sure the handshake has already completed
<-h.handshakeDone
done := make(chan struct{})
defer close(done)
// h.alertChan is an unbuffered channel.
// If an error occurs during conn.HandlePostHandshakeMessage,
// it will be sent on this channel.
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
case <-done:
}
}()
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
select {
case <-h.closeChan:
case alert := <-alertChan:
h.onError(alert, err.Error())
}
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
select {
case h.isReadingHandshakeMessage <- struct{}{}:
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
select {
case msg := <-h.messageChan:
return msg, nil
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
h.zeroRTTOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
}
return
case qtls.EncryptionHandshake:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newHandshakeOpener(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
h.dropInitialKeys,
h.perspective,
)
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
case qtls.EncryptionApplication:
h.readEncLevel = protocol.Encryption1RTT
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
h.zeroRTTSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
return
case qtls.EncryptionHandshake:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newHandshakeSealer(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
h.dropInitialKeys,
h.perspective,
)
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
case qtls.EncryptionApplication:
h.writeEncLevel = protocol.Encryption1RTT
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.zeroRTTSealer != nil {
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.clientHelloWrittenChan <- h.zeroRTTParameters
} else {
h.logger.Debugf("Not doing 0-RTT.")
h.clientHelloWrittenChan <- nil
}
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
}
}
func (h *cryptoSetup) SendAlert(alert uint8) {
select {
case h.alertChan <- alert:
case <-h.closeChan:
// no need to send an alert when we've already closed
}
}
// used a callback in the handshakeSealer and handshakeOpener
func (h *cryptoSetup) dropInitialKeys() {
h.mutex.Lock()
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
h.runner.DropKeys(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
dropped = true
}
h.mutex.Unlock()
if dropped {
h.runner.DropKeys(protocol.EncryptionHandshake)
h.logger.Debugf("Dropping Handshake keys.")
}
}
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return h.initialSealer, nil
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped
}
return h.zeroRTTSealer, nil
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil {
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return nil, ErrKeysNotYetAvailable
}
return h.handshakeSealer, nil
}
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.zeroRTTOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
if !h.has1RTTOpener {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return qtls.GetConnectionState(h.conn)
}

View File

@@ -0,0 +1,127 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"golang.org/x/crypto/chacha20"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
type headerProtector interface {
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
return newAESHeaderProtector(suite, trafficSecret, isLongHeader)
case tls.TLS_CHACHA20_POLY1305_SHA256:
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader)
default:
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
}
}
type aesHeaderProtector struct {
mask []byte
block cipher.Block
isLongHeader bool
}
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return &aesHeaderProtector{
block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader,
}
}
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask, sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}
type chachaHeaderProtector struct {
mask [5]byte
key [32]byte
isLongHeader bool
}
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
p := &chachaHeaderProtector{
isLongHeader: isLongHeader,
}
copy(p.key[:], hpKey)
return p
}
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != 16 {
panic("invalid sample size")
}
for i := 0; i < 5; i++ {
p.mask[i] = 0
}
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
if err != nil {
panic(err)
}
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
cipher.XORKeyStream(p.mask[:], p.mask[:])
p.applyMask(firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}

View File

@@ -0,0 +1,29 @@
package handshake
import (
"crypto"
"encoding/binary"
"golang.org/x/crypto/hkdf"
)
// hkdfExpandLabel HKDF expands a label.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))
b[2] = uint8(6 + len(label))
b = append(b, []byte("tls13 ")...)
b = append(b, []byte(label)...)
b = b[:3+6+len(label)+1]
b[3+6+len(label)] = uint8(len(context))
b = append(b, context...)
out := make([]byte, length)
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
if err != nil || n != length {
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}

View File

@@ -0,0 +1,64 @@
package handshake
import (
"crypto"
"crypto/tls"
"golang.org/x/crypto/hkdf"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
var (
quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
quicSaltDraft34 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
)
func getSalt(v protocol.VersionNumber) []byte {
if v == protocol.VersionDraft34 || v == protocol.Version1 {
return quicSaltDraft34
}
return quicSaltOld
}
var initialSuite = &qtls.CipherSuiteTLS13{
ID: tls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,
}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID, v)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true))
}
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v))
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte) (key, iv []byte) {
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
return
}

View File

@@ -0,0 +1,102 @@
package handshake
import (
"errors"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var (
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have already been dropped.
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
ErrDecryptionFailed = errors.New("decryption failed")
)
// ConnectionState contains information about the state of the connection.
type ConnectionState = qtls.ConnectionState
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet
type LongHeaderSealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
// ShortHeaderSealer seals a short header packet
type ShortHeaderSealer interface {
LongHeaderSealer
KeyPhase() protocol.KeyPhaseBit
}
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
TransportParameters() <-chan []byte
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
DropKeys(protocol.EncryptionLevel)
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake()
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
ConnectionState() ConnectionState
GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (LongHeaderOpener, error)
Get0RTTOpener() (LongHeaderOpener, error)
Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error)
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}
// ConnWithVersion is the connection used in the ClientHelloInfo.
// It can be used to determine the QUIC version in use.
type ConnWithVersion interface {
net.Conn
GetQUICVersion() protocol.VersionNumber
}

View File

@@ -0,0 +1,3 @@
package handshake
//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"

View File

@@ -0,0 +1,62 @@
package handshake
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34
retryAEAD cipher.AEAD // used for QUIC draft-34
)
func init() {
oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1})
retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
}
func initAEAD(key [16]byte) cipher.AEAD {
aes, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
return aead
}
var (
retryBuf bytes.Buffer
retryMutex sync.Mutex
oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c}
retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
)
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
retryMutex.Lock()
retryBuf.WriteByte(uint8(origDestConnID.Len()))
retryBuf.Write(origDestConnID.Bytes())
retryBuf.Write(retry)
var tag [16]byte
var sealed []byte
if version != protocol.VersionDraft34 && version != protocol.Version1 {
sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes())
} else {
sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes())
}
if len(sealed) != 16 {
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
}
retryBuf.Reset()
retryMutex.Unlock()
return &tag
}

View File

@@ -0,0 +1,48 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/quicvarint"
)
const sessionTicketRevision = 2
type sessionTicket struct {
Parameters *wire.TransportParameters
RTT time.Duration // to be encoded in mus
}
func (t *sessionTicket) Marshal() []byte {
b := &bytes.Buffer{}
quicvarint.Write(b, sessionTicketRevision)
quicvarint.Write(b, uint64(t.RTT.Microseconds()))
t.Parameters.MarshalForSessionTicket(b)
return b.Bytes()
}
func (t *sessionTicket) Unmarshal(b []byte) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
if err != nil {
return errors.New("failed to read session ticket revision")
}
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
if err != nil {
return errors.New("failed to read RTT")
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
t.RTT = time.Duration(rtt) * time.Microsecond
return nil
}

View File

@@ -0,0 +1,68 @@
package handshake
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
const (
quicTLSExtensionTypeOldDrafts = 0xffa5
quicTLSExtensionType = 0x39
)
type extensionHandler struct {
ourParams []byte
paramsChan chan []byte
extensionType uint16
perspective protocol.Perspective
}
var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler {
et := uint16(quicTLSExtensionType)
if v != protocol.VersionDraft34 && v != protocol.Version1 {
et = quicTLSExtensionTypeOldDrafts
}
return &extensionHandler{
ourParams: params,
paramsChan: make(chan []byte),
perspective: pers,
extensionType: et,
}
}
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
return nil
}
return []qtls.Extension{{
Type: h.extensionType,
Data: h.ourParams,
}}
}
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
return
}
var data []byte
for _, ext := range exts {
if ext.Type == h.extensionType {
data = ext.Data
break
}
}
h.paramsChan <- data
}
func (h *extensionHandler) TransportParameters() <-chan []byte {
return h.paramsChan
}

View File

@@ -0,0 +1,134 @@
package handshake
import (
"encoding/asn1"
"fmt"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
tokenPrefixIP byte = iota
tokenPrefixString
)
// A Token is derived from the client address and can be used to verify the ownership of this address.
type Token struct {
IsRetryToken bool
RemoteAddr string
SentTime time.Time
// only set for retry tokens
OriginalDestConnectionID protocol.ConnectionID
RetrySrcConnectionID protocol.ConnectionID
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
IsRetryToken bool
RemoteAddr []byte
Timestamp int64
OriginalDestConnectionID []byte
RetrySrcConnectionID []byte
}
// A TokenGenerator generates tokens
type TokenGenerator struct {
tokenProtector tokenProtector
}
// NewTokenGenerator initializes a new TookenGenerator
func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) {
tokenProtector, err := newTokenProtector(rand)
if err != nil {
return nil, err
}
return &TokenGenerator{
tokenProtector: tokenProtector,
}, nil
}
// NewRetryToken generates a new token for a Retry for a given source address
func (g *TokenGenerator) NewRetryToken(
raddr net.Addr,
origDestConnID protocol.ConnectionID,
retrySrcConnID protocol.ConnectionID,
) ([]byte, error) {
data, err := asn1.Marshal(token{
IsRetryToken: true,
RemoteAddr: encodeRemoteAddr(raddr),
OriginalDestConnectionID: origDestConnID,
RetrySrcConnectionID: retrySrcConnID,
Timestamp: time.Now().UnixNano(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// NewToken generates a new token to be sent in a NEW_TOKEN frame
func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
RemoteAddr: encodeRemoteAddr(raddr),
Timestamp: time.Now().UnixNano(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// DecodeToken decodes a token
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
// if the client didn't send any token, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.tokenProtector.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
token := &Token{
IsRetryToken: t.IsRetryToken,
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
SentTime: time.Unix(0, t.Timestamp),
}
if t.IsRetryToken {
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID)
}
return token, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
}
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the token
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a token that we generated.
// Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == tokenPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View File

@@ -0,0 +1,89 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// TokenProtector is used to create and verify a token
type tokenProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const (
tokenSecretSize = 32
tokenNonceSize = 32
)
// tokenProtector is used to create and verify a token
type tokenProtectorImpl struct {
rand io.Reader
secret []byte
}
// newTokenProtector creates a source for source address tokens
func newTokenProtector(rand io.Reader) (tokenProtector, error) {
secret := make([]byte, tokenSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &tokenProtectorImpl{
rand: rand,
secret: secret,
}, nil
}
// NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, tokenNonceSize)
if _, err := s.rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
if len(p) < tokenNonceSize {
return nil, fmt.Errorf("token too short: %d", len(p))
}
nonce := p[:tokenNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View File

@@ -0,0 +1,321 @@
package handshake
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
)
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
// It's a package-level variable to allow modifying it for testing purposes.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool
keyUpdateInterval uint64
invalidPacketLimit uint64
invalidPacketCount uint64
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
prevRcvAEADExpiry time.Time
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
sendAEAD cipher.AEAD
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
aeadOverhead int
nextRcvAEAD cipher.AEAD
nextSendAEAD cipher.AEAD
nextRcvTrafficSecret []byte
nextSendTrafficSecret []byte
headerDecrypter headerProtector
headerEncrypter headerProtector
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
logger utils.Logger
// use a single slice to avoid allocations
nonceBuf []byte
}
var (
_ ShortHeaderOpener = &updatableAEAD{}
_ ShortHeaderSealer = &updatableAEAD{}
)
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: KeyUpdateInterval,
rttStats: rttStats,
tracer: tracer,
logger: logger,
}
}
func (a *updatableAEAD) rollKeys() {
if a.prevRcvAEAD != nil {
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
if a.tracer != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
a.prevRcvAEADExpiry = time.Time{}
}
a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
d := 3 * a.rttStats.PTO(true)
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
a.prevRcvAEADExpiry = now.Add(d)
}
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret)
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
if a.suite == nil {
a.setAEADParameters(a.rcvAEAD, suite)
}
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
}
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
if a.suite == nil {
a.setAEADParameters(a.sendAEAD, suite)
}
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
a.nonceBuf = make([]byte, aead.NonceSize())
a.aeadOverhead = aead.Overhead()
a.suite = suite
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
case tls.TLS_CHACHA20_POLY1305_SHA256:
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
default:
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
}
}
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed {
a.invalidPacketCount++
if a.invalidPacketCount >= a.invalidPacketLimit {
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
}
}
if err == nil {
a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
}
return dec, err
}
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
a.prevRcvAEADExpiry = time.Time{}
if a.tracer != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() {
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
}
// try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return nil, ErrDecryptionFailed
}
// Opening succeeded. Check if the peer was allowed to update.
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: "keys updated too quickly",
}
}
a.rollKeys()
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
// Start a timer to drop the previous key generation.
a.startKeyDropTimer(rcvTime)
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, true)
}
a.firstRcvdWithCurrentKey = pn
return dec, err
}
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return dec, ErrDecryptionFailed
}
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
// We initiated the key updated, and now we received the first packet protected with the new key phase.
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
if a.keyPhase > 0 {
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
a.startKeyDropTimer(rcvTime)
}
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
a.firstSentWithCurrentKey = pn
}
if a.firstPacketNumber == protocol.InvalidPacketNumber {
a.firstPacketNumber = pn
}
a.numSentWithCurrentKey++
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
return &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
}
}
a.largestAcked = pn
return nil
}
func (a *updatableAEAD) SetHandshakeConfirmed() {
a.handshakeConfirmed = true
}
func (a *updatableAEAD) updateAllowed() bool {
if !a.handshakeConfirmed {
return false
}
// the first key update is allowed as soon as the handshake is confirmed
return a.keyPhase == 0 ||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey)
}
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() {
return false
}
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true
}
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true
}
return false
}
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() {
a.rollKeys()
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, false)
}
}
return a.keyPhase.Bit()
}
func (a *updatableAEAD) Overhead() int {
return a.aeadOverhead
}
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
return a.firstPacketNumber
}