fix application bug

This commit is contained in:
Jeff
2019-05-13 11:19:18 +08:00
committed by zryfish
parent 996d6fe4c5
commit 5462f51e65
717 changed files with 87703 additions and 53426 deletions

View File

@@ -0,0 +1,104 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sealer struct {
aead cipher.AEAD
hpEncrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
hpMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Sealer = &sealer{}
func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) Sealer {
return &sealer{
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
hpEncrypter: hpEncrypter,
hpMask: make([]byte, hpEncrypter.BlockSize()),
}
}
func (s *sealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
}
func (s *sealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != s.hpEncrypter.BlockSize() {
panic("invalid sample size")
}
s.hpEncrypter.Encrypt(s.hpMask, sample)
if s.is1RTT {
*firstByte ^= s.hpMask[0] & 0x1f
} else {
*firstByte ^= s.hpMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= s.hpMask[i+1]
}
}
func (s *sealer) Overhead() int {
return s.aead.Overhead()
}
type opener struct {
aead cipher.AEAD
pnDecrypter cipher.Block
// use a single slice to avoid allocations
nonceBuf []byte
hpMask []byte
// short headers protect 5 bits in the first byte, long headers only 4
is1RTT bool
}
var _ Opener = &opener{}
func newOpener(aead cipher.AEAD, pnDecrypter cipher.Block, is1RTT bool) Opener {
return &opener{
aead: aead,
nonceBuf: make([]byte, aead.NonceSize()),
is1RTT: is1RTT,
pnDecrypter: pnDecrypter,
hpMask: make([]byte, pnDecrypter.BlockSize()),
}
}
func (o *opener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return o.aead.Open(dst, o.nonceBuf, src, ad)
}
func (o *opener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
if len(sample) != o.pnDecrypter.BlockSize() {
panic("invalid sample size")
}
o.pnDecrypter.Encrypt(o.hpMask, sample)
if o.is1RTT {
*firstByte ^= o.hpMask[0] & 0x1f
} else {
*firstByte ^= o.hpMask[0] & 0xf
}
for i := range pnBytes {
pnBytes[i] ^= o.hpMask[i+1]
}
}

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
@@ -14,14 +16,17 @@ const (
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
type Cookie struct {
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
RemoteAddr string
OriginalDestConnectionID protocol.ConnectionID
// The time that the Cookie was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
RemoteAddr []byte
OriginalDestConnectionID []byte
Timestamp int64
}
@@ -42,10 +47,11 @@ func NewCookieGenerator() (*CookieGenerator, error) {
}
// NewToken generates a new Cookie for a given source address
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
func (g *CookieGenerator) NewToken(raddr net.Addr, origConnID protocol.ConnectionID) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
RemoteAddr: encodeRemoteAddr(raddr),
OriginalDestConnectionID: origConnID,
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
@@ -72,10 +78,14 @@ func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &Cookie{
RemoteAddr: decodeRemoteAddr(t.Data),
cookie := &Cookie{
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
if len(t.OriginalDestConnectionID) > 0 {
cookie.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
}
return cookie, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie

View File

@@ -0,0 +1,581 @@
package handshake
import (
"crypto/aes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"unsafe"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/marten-seemann/qtls"
)
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
// ErrOpenerNotYetAvailable is returned when an opener is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
var ErrOpenerNotYetAvailable = errors.New("CryptoSetup: opener at this encryption level not yet available")
type cryptoSetup struct {
tlsConf *qtls.Config
conn *qtls.Conn
messageChan chan []byte
paramsChan <-chan []byte
handleParamsCallback func([]byte)
alertChan chan uint8
// HandleData() sends errors on the messageErrChan
messageErrChan chan error
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// is closed when Close() is called
closeChan chan struct{}
clientHelloWritten bool
clientHelloWrittenChan chan struct{}
receivedWriteKey chan struct{}
receivedReadKey chan struct{}
// WriteRecord does a non-blocking send on this channel.
// This way, handleMessage can see if qtls tries to write a message.
// This is necessary:
// for servers: to see if a HelloRetryRequest should be sent in response to a ClientHello
// for clients: to see if a ServerHello is a HelloRetryRequest
writeRecord chan struct{}
logger utils.Logger
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
initialStream io.Writer
initialOpener Opener
initialSealer Sealer
handshakeStream io.Writer
handshakeOpener Opener
handshakeSealer Sealer
oneRTTStream io.Writer
opener Opener
sealer Sealer
}
var _ qtls.RecordLayer = &cryptoSetup{}
var _ CryptoSetup = &cryptoSetup{}
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, <-chan struct{} /* ClientHello written */, error) {
cs, clientHelloWritten, err := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
handleParams,
tlsConf,
logger,
protocol.PerspectiveClient,
)
if err != nil {
return nil, nil, err
}
cs.conn = qtls.Client(newConn(remoteAddr), cs.tlsConf)
return cs, clientHelloWritten, nil
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
remoteAddr net.Addr,
tp *TransportParameters,
handleParams func([]byte),
tlsConf *tls.Config,
logger utils.Logger,
) (CryptoSetup, error) {
cs, _, err := newCryptoSetup(
initialStream,
handshakeStream,
oneRTTStream,
connID,
tp,
handleParams,
tlsConf,
logger,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
cs.conn = qtls.Server(newConn(remoteAddr), cs.tlsConf)
return cs, nil
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
oneRTTStream io.Writer,
connID protocol.ConnectionID,
tp *TransportParameters,
handleParams func([]byte),
tlsConf *tls.Config,
logger utils.Logger,
perspective protocol.Perspective,
) (*cryptoSetup, <-chan struct{} /* ClientHello written */, error) {
initialSealer, initialOpener, err := NewInitialAEAD(connID, perspective)
if err != nil {
return nil, nil, err
}
extHandler := newExtensionHandler(tp.Marshal(), perspective)
cs := &cryptoSetup{
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
handleParamsCallback: handleParams,
paramsChan: extHandler.TransportParameters(),
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
messageErrChan: make(chan error, 1),
clientHelloWrittenChan: make(chan struct{}),
messageChan: make(chan []byte, 100),
receivedReadKey: make(chan struct{}),
receivedWriteKey: make(chan struct{}),
writeRecord: make(chan struct{}),
closeChan: make(chan struct{}),
}
qtlsConf := tlsConfigToQtlsConfig(tlsConf, cs, extHandler)
cs.tlsConf = qtlsConf
return cs, cs.clientHelloWrittenChan, nil
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) error {
initialSealer, initialOpener, err := NewInitialAEAD(id, h.perspective)
if err != nil {
return err
}
h.initialSealer = initialSealer
h.initialOpener = initialOpener
return nil
}
func (h *cryptoSetup) RunHandshake() error {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.Handshake(); err != nil {
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-h.closeChan:
close(h.messageChan)
// wait until the Handshake() go routine has returned
return errors.New("Handshake aborted")
case <-handshakeComplete: // return when the handshake is done
return nil
case alert := <-h.alertChan:
err := <-handshakeErrChan
return qerr.CryptoError(alert, err.Error())
case err := <-h.messageErrChan:
// If the handshake errored because of an error that occurred during HandleData(),
// that error message will be more useful than the error message generated by Handshake().
// Close the message chan that qtls is receiving messages from.
// This will make qtls.Handshake() return.
// Thereby the go routine running qtls.Handshake() will return.
close(h.messageChan)
return err
}
}
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
}
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.messageErrChan <- err
return false
}
h.messageChan <- data
switch h.perspective {
case protocol.PerspectiveClient:
return h.handleMessageForClient(msgType)
case protocol.PerspectiveServer:
return h.handleMessageForServer(msgType)
default:
panic("")
}
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello,
typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
case typeNewSessionTicket:
expected = protocol.Encryption1RTT
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
}
return nil
}
func (h *cryptoSetup) handleMessageForServer(msgType messageType) bool {
switch msgType {
case typeClientHello:
select {
case <-h.writeRecord:
// If qtls sends a HelloRetryRequest, it will only write the record.
// If it accepts the ClientHello, it will first read the transport parameters.
h.logger.Debugf("Sending HelloRetryRequest")
return false
case data := <-h.paramsChan:
h.handleParamsCallback(data)
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the 1-RTT write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
case typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
default:
panic("unexpected handshake message")
}
}
func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
switch msgType {
case typeServerHello:
// get the handshake write key
select {
case <-h.writeRecord:
// If qtls writes in response to a ServerHello, this means that this ServerHello
// is a HelloRetryRequest.
// Otherwise, we'd just wait for the Certificate message.
h.logger.Debugf("ServerHello is a HelloRetryRequest")
return false
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
// get the handshake read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
return true
case typeEncryptedExtensions:
select {
case data := <-h.paramsChan:
h.handleParamsCallback(data)
case <-h.handshakeDone:
return false
}
return false
case typeCertificateRequest, typeCertificate, typeCertificateVerify:
// nothing to do
return false
case typeFinished:
// get the 1-RTT read key
select {
case <-h.receivedReadKey:
case <-h.handshakeDone:
return false
}
// get the handshake write key
select {
case <-h.receivedWriteKey:
case <-h.handshakeDone:
return false
}
return true
case typeNewSessionTicket:
<-h.handshakeDone // don't process session tickets before the handshake has completed
h.conn.HandlePostHandshakeMessage()
return false
default:
panic("unexpected handshake message: ")
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
msg, ok := <-h.messageChan
if !ok {
return nil, errors.New("error while handling the handshake message")
}
return msg, nil
}
func (h *cryptoSetup) SetReadKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
hpDecrypter, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
h.mutex.Lock()
switch h.readEncLevel {
case protocol.EncryptionInitial:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newOpener(suite.AEAD(key, iv), hpDecrypter, false)
h.logger.Debugf("Installed Handshake Read keys")
case protocol.EncryptionHandshake:
h.readEncLevel = protocol.Encryption1RTT
h.opener = newOpener(suite.AEAD(key, iv), hpDecrypter, true)
h.logger.Debugf("Installed 1-RTT Read keys")
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
h.receivedReadKey <- struct{}{}
}
func (h *cryptoSetup) SetWriteKey(suite *qtls.CipherSuite, trafficSecret []byte) {
key := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic key", suite.KeyLen())
iv := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic iv", suite.IVLen())
hpKey := qtls.HkdfExpandLabel(suite.Hash(), trafficSecret, []byte{}, "quic hp", suite.KeyLen())
hpEncrypter, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
h.mutex.Lock()
switch h.writeEncLevel {
case protocol.EncryptionInitial:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newSealer(suite.AEAD(key, iv), hpEncrypter, false)
h.logger.Debugf("Installed Handshake Write keys")
case protocol.EncryptionHandshake:
h.writeEncLevel = protocol.Encryption1RTT
h.sealer = newSealer(suite.AEAD(key, iv), hpEncrypter, true)
h.logger.Debugf("Installed 1-RTT Write keys")
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
h.receivedWriteKey <- struct{}{}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
defer func() {
select {
case h.writeRecord <- struct{}{}:
default:
}
}()
h.mutex.Lock()
defer h.mutex.Unlock()
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
close(h.clientHelloWrittenChan)
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
case protocol.Encryption1RTT:
return h.oneRTTStream.Write(p)
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
}
}
func (h *cryptoSetup) SendAlert(alert uint8) {
h.alertChan <- alert
}
func (h *cryptoSetup) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.sealer != nil {
return protocol.Encryption1RTT, h.sealer
}
if h.handshakeSealer != nil {
return protocol.EncryptionHandshake, h.handshakeSealer
}
return protocol.EncryptionInitial, h.initialSealer
}
func (h *cryptoSetup) GetSealerWithEncryptionLevel(level protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", level.String())
h.mutex.Lock()
defer h.mutex.Unlock()
switch level {
case protocol.EncryptionInitial:
return h.initialSealer, nil
case protocol.EncryptionHandshake:
if h.handshakeSealer == nil {
return nil, errNoSealer
}
return h.handshakeSealer, nil
case protocol.Encryption1RTT:
if h.sealer == nil {
return nil, errNoSealer
}
return h.sealer, nil
default:
return nil, errNoSealer
}
}
func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
switch level {
case protocol.EncryptionInitial:
return h.initialOpener, nil
case protocol.EncryptionHandshake:
if h.handshakeOpener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.handshakeOpener, nil
case protocol.Encryption1RTT:
if h.opener == nil {
return nil, ErrOpenerNotYetAvailable
}
return h.opener, nil
default:
return nil, fmt.Errorf("CryptoSetup: no opener with encryption level %s", level)
}
}
func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
cs := h.conn.ConnectionState()
// h.conn is a qtls.Conn, which returns a qtls.ConnectionState.
// qtls.ConnectionState is identical to the tls.ConnectionState.
// It contains an unexported field which is used ExportKeyingMaterial().
// The only way to return a tls.ConnectionState is to use unsafe.
// In unsafe.go we check that the two objects are actually identical.
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
}

View File

@@ -1,542 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type cryptoSetupClient struct {
mutex sync.RWMutex
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber
cryptoStream io.ReadWriter
serverConfig *serverConfigClient
stk []byte
sno []byte
nonc []byte
proof []byte
chloForSignature []byte
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan struct{}
diversificationNonce []byte
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
params *TransportParameters
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupClient{}
var (
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
)
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
version protocol.VersionNumber,
tlsConf *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: tlsConf.ServerName,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConf),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
divNonceChan: divNonceChan,
logger: logger,
}
return cs, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error, 1)
go func() {
for {
message, err := ParseHandshakeMessage(h.cryptoStream)
if err != nil {
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
return
}
messageChan <- message
}
}()
for {
if err := h.maybeUpgradeCrypto(); err != nil {
return err
}
h.mutex.RLock()
sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
if err := h.sendCHLO(); err != nil {
return err
}
}
var message HandshakeMessage
select {
case <-h.divNonceChan:
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err := <-errorChan:
return err
}
h.logger.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
default:
return qerr.InvalidCryptoMessageType
}
}
}
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
var err error
if stk, ok := cryptoData[TagSTK]; ok {
h.stk = stk
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
// TODO: what happens if the server sends a different server config in two packets?
if scfg, ok := cryptoData[TagSCFG]; ok {
h.serverConfig, err = parseServerConfig(scfg)
if err != nil {
return err
}
if h.serverConfig.IsExpired() {
return qerr.CryptoServerConfigExpired
}
// now that we have a server config, we can use its OBIT value to generate a client nonce
if len(h.nonc) == 0 {
err = h.generateClientNonce()
if err != nil {
return err
}
}
}
if proof, ok := cryptoData[TagPROF]; ok {
h.proof = proof
h.chloForSignature = h.lastSentCHLO
}
if crt, ok := cryptoData[TagCERT]; ok {
err := h.certManager.SetData(crt)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
err = h.certManager.Verify(h.hostname)
if err != nil {
h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid
}
}
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof {
h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid
}
h.serverVerified = true
}
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return nil, err
}
leafCert := h.certManager.GetLeafCert()
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
nil,
protocol.PerspectiveClient,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData)
if err != nil {
return nil, qerr.InvalidCryptoMessageParameter
}
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
numNegotiatedVersions := len(h.negotiatedVersions)
if numNegotiatedVersions == 0 {
return true
}
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
return false
}
b := bytes.NewReader(verTags)
for i := 0; i < numNegotiatedVersions; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { // should never occur, since the length was already checked
return false
}
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
return false
}
}
return true
}
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
return data, protocol.EncryptionForwardSecure, nil
}
return nil, protocol.EncryptionUnspecified, err
}
if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, nil
}
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
} else {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
}
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
return nil
}
h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
}
b := &bytes.Buffer{}
tags, err := h.getTags()
if err != nil {
return err
}
h.addPadding(tags)
message := HandshakeMessage{
Tag: TagCHLO,
Data: tags,
}
h.logger.Debugf("Sending %s", message)
message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes())
if err != nil {
return err
}
h.lastSentCHLO = b.Bytes()
return nil
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
ccs := h.certManager.GetCommonCertificateHashes()
if len(ccs) > 0 {
tags[TagCCS] = ccs
}
versionTag := make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
tags[TagVER] = versionTag
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}
if len(h.sno) > 0 {
tags[TagSNO] = h.sno
}
if h.serverConfig != nil {
tags[TagSCID] = h.serverConfig.ID
leafCert := h.certManager.GetLeafCert()
if leafCert != nil {
certHash, _ := h.certManager.GetLeafCertHash()
xlct := make([]byte, 8)
binary.LittleEndian.PutUint64(xlct, certHash)
tags[TagNONC] = h.nonc
tags[TagXLCT] = xlct
tags[TagKEXS] = []byte("C255")
tags[TagAEAD] = []byte("AESG")
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
}
}
return tags, nil
}
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
var size int
for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
}
paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
}
}
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if !h.serverVerified {
return nil
}
h.mutex.Lock()
defer h.mutex.Unlock()
leafCert := h.certManager.GetLeafCert()
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
var err error
var nonce []byte
if h.sno == nil {
nonce = h.nonc
} else {
nonce = append(h.nonc, h.sno...)
}
h.secureAEAD, err = h.keyDerivation(
false,
h.serverConfig.sharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
h.diversificationNonce,
protocol.PerspectiveClient,
)
if err != nil {
return err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
}
return nil
}
func (h *cryptoSetupClient) generateClientNonce() error {
if len(h.nonc) > 0 {
return errClientNonceAlreadyExists
}
nonc := make([]byte, 32)
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
if len(h.serverConfig.obit) != 8 {
return errNoObitForClientNonce
}
copy(nonc[4:12], h.serverConfig.obit)
_, err := rand.Read(nonc[12:])
if err != nil {
return err
}
h.nonc = nonc
return nil
}

View File

@@ -1,467 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// QuicCryptoKeyDerivationFunction is used for key derivation
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *Cookie) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
receivedForwardSecurePacket bool
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
cryptoStream io.ReadWriter
params *TransportParameters
sni string // need to fill out the ConnectionState
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupServer{}
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
remoteAddr net.Addr,
version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig,
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
logger: logger,
}, nil
}
// HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) HandleCryptoStream() error {
for {
var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
if err != nil {
return qerr.HandshakeFailed
}
if message.Tag != TagCHLO {
return qerr.InvalidCryptoMessageType
}
h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI]
if !ok {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
sni := string(sniSlice)
if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
h.sni = sni
// prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
verSlice, ok := cryptoData[TagVER]
if !ok {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
}
if len(verSlice) != 4 {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
}
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
var reply []byte
var err error
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return false, err
}
params, err := readHelloMap(cryptoData)
if err != nil {
return false, err
}
// blocks until the session has received the parameters
if !h.receivedParams {
h.receivedParams = true
h.paramsChan <- *params
}
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
reply, err = h.handleCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err
}
h.handshakeEvent <- struct{}{}
close(h.sentSHLO)
return true, nil
}
// We have an inchoate or non-matching CHLO, we now send a rejection
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
_, err = h.cryptoStream.Write(reply)
return false, err
}
// Open a message
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.handshakeEvent)
}
return res, protocol.EncryptionForwardSecure, nil
}
if h.receivedForwardSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return res, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, err
}
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupServer: no encryption level specified")
}
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
if _, ok := cryptoData[TagPUBS]; !ok {
return true
}
scid, ok := cryptoData[TagSCID]
if !ok || !bytes.Equal(h.scfg.ID, scid) {
return true
}
xlctTag, ok := cryptoData[TagXLCT]
if !ok || len(xlctTag) != 8 {
return true
}
xlct := binary.LittleEndian.Uint64(xlctTag)
if crypto.HashCert(cert) != xlct {
return true
}
return !h.acceptSTK(cryptoData[TagSTK])
}
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil {
h.logger.Debugf("STK invalid: %s", err.Error())
return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
}
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
if err != nil {
return nil, err
}
replyMap := map[Tag][]byte{
TagSCFG: h.scfg.Get(),
TagSTK: token,
TagSVID: []byte("quic-go"),
}
if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo)
if err != nil {
return nil, err
}
commonSetHashes := cryptoData[TagCCS]
cachedCertsHashes := cryptoData[TagCCRT]
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
if err != nil {
return nil, err
}
// Token was valid, send more details
replyMap[TagPROF] = proof
replyMap[TagCERT] = certCompressed
}
message := HandshakeMessage{
Tag: TagREJ,
Data: replyMap,
}
var serverReply bytes.Buffer
message.Write(&serverReply)
h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil
}
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.mutex.Lock()
defer h.mutex.Unlock()
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return nil, err
}
serverNonce := make([]byte, 32)
if _, err = rand.Read(serverNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce)
if err != nil {
return nil, err
}
aead := cryptoData[TagAEAD]
if !bytes.Equal(aead, []byte("AESG")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
kexs := cryptoData[TagKEXS]
if !bytes.Equal(kexs, []byte("C255")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
h.secureAEAD, err = h.keyDerivation(
false,
sharedSecret,
clientNonce,
h.connID,
data,
h.scfg.Get(),
certUncompressed,
h.diversificationNonce,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer
fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce)
ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
fsNonce.Bytes(),
h.connID,
data,
h.scfg.Get(),
certUncompressed,
nil,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce
replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
message := HandshakeMessage{
Tag: TagSHLO,
Data: replyMap,
}
var reply bytes.Buffer
message.Write(&reply)
h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
}
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
}
return nil
}

View File

@@ -1,163 +0,0 @@
package handshake
import (
"errors"
"fmt"
"io"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
type cryptoSetupTLS struct {
mutex sync.RWMutex
perspective protocol.Perspective
keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD
aead crypto.AEAD
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
state := h.tls.ConnectionState().HandshakeState
if err := h.conn.Flush(); err != nil {
return err
}
if state == mint.StateClientConnected || state == mint.StateServerConnected {
break
}
}
aead, err := h.keyDerivation(h.tls, h.perspective)
if err != nil {
return err
}
h.mutex.Lock()
h.aead = aead
h.mutex.Unlock()
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
return nil
}
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead == nil {
return nil, errors.New("no 1-RTT sealer")
}
return h.aead.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead != nil {
return protocol.EncryptionForwardSecure, h.aead
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String())
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionForwardSecure:
if h.aead == nil {
return nil, errNoSealer
}
return h.aead, nil
default:
return nil, errNoSealer
}
}
func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
mintConnState := h.tls.ConnectionState()
return ConnectionState{
// TODO: set the ServerName, once mint exports it
HandshakeComplete: h.aead != nil,
PeerCertificates: mintConnState.PeerCertificates,
}
}

View File

@@ -1,69 +0,0 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
type cryptoStreamConn struct {
buffer *bytes.Buffer
stream io.ReadWriter
}
var _ net.Conn = &cryptoStreamConn{}
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
return &cryptoStreamConn{
stream: stream,
buffer: &bytes.Buffer{},
}
}
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
return c.stream.Read(b)
}
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
return c.buffer.Write(p)
}
func (c *cryptoStreamConn) Flush() error {
if c.buffer.Len() == 0 {
return nil
}
_, err := c.stream.Write(c.buffer.Bytes())
c.buffer.Reset()
return err
}
// Close is not implemented
func (c *cryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *cryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr is not implemented
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
return nil
}
// SetReadDeadline is not implemented
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View File

@@ -1,48 +0,0 @@
package handshake
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
kexLifetime = protocol.EphermalKeyLifetime
kexCurrent crypto.KeyExchange
kexCurrentTime time.Time
kexMutex sync.RWMutex
)
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
// See the explanation from the QUIC crypto doc:
//
// A single connection is the usual scope for forward security, but the security
// difference between an ephemeral key used for a single connection, and one
// used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a
// small time span.
func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock()
res := kexCurrent
t := kexCurrentTime
kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime {
return res, nil
}
kexMutex.Lock()
defer kexMutex.Unlock()
// Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
kexCurrent = kex
kexCurrentTime = time.Now()
return kexCurrent, nil
}
return kexCurrent, nil
}

View File

@@ -1,137 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A HandshakeMessage is a handshake message
type HandshakeMessage struct {
Tag Tag
Data map[Tag][]byte
}
var _ fmt.Stringer = &HandshakeMessage{}
// ParseHandshakeMessage reads a crypto message
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
slice4 := make([]byte, 4)
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
nPairs := binary.LittleEndian.Uint32(slice4)
if nPairs > protocol.CryptoMaxParams {
return HandshakeMessage{}, qerr.CryptoTooManyEntries
}
index := make([]byte, nPairs*8)
if _, err := io.ReadFull(r, index); err != nil {
return HandshakeMessage{}, err
}
resultMap := map[Tag][]byte{}
var dataStart uint32
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
dataLen := dataEnd - dataStart
if dataLen > protocol.CryptoParameterMaxLength {
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
}
data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil {
return HandshakeMessage{}, err
}
resultMap[tag] = data
dataStart = dataEnd
}
return HandshakeMessage{
Tag: messageTag,
Data: resultMap}, nil
}
// Write writes a crypto message
func (h HandshakeMessage) Write(b *bytes.Buffer) {
data := h.Data
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
utils.LittleEndian.WriteUint16(b, 0)
// Save current position in the buffer, so that we can update the index in-place later
indexStart := b.Len()
indexData := make([]byte, 8*len(data))
b.Write(indexData) // Will be updated later
offset := uint32(0)
for i, t := range h.getTagsSorted() {
v := data[t]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
}
// Now we write the index data for real
copy(b.Bytes()[indexStart:], indexData)
}
func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]Tag, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = t
i++
}
sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags
}
func (h HandshakeMessage) String() string {
var pad string
res := tagToString(h.Tag) + ":\n"
for _, tag := range h.getTagsSorted() {
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
}
}
if len(pad) > 0 {
res += pad
}
return res
}
func tagToString(tag Tag) string {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, uint32(tag))
for i := range b {
if b[i] == 0 {
b[i] = ' '
}
}
return string(b)
}

View File

@@ -0,0 +1,52 @@
package handshake
import (
"crypto"
"crypto/aes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) {
clientSecret, serverSecret := computeSecrets(connID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myHPKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, otherHPKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
hpEncrypter, err := aes.NewCipher(myHPKey)
if err != nil {
return nil, nil, err
}
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
hpDecrypter, err := aes.NewCipher(otherHPKey)
if err != nil {
return nil, nil, err
}
return newSealer(encrypter, hpEncrypter, false), newOpener(decrypter, hpDecrypter, false), nil
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
initialSecret := qtls.HkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
clientSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = qtls.HkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte) (key, hpKey, iv []byte) {
key = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
hpKey = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic hp", 16)
iv = qtls.HkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
return
}

View File

@@ -1,55 +1,46 @@
package handshake
import (
"crypto/tls"
"crypto/x509"
"io"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
// Opener opens a packet
type Opener interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// Sealer seals a packet
type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
// mintTLS combines some methods needed to interact with mint.
type mintTLS interface {
crypto.TLSExporter
Handshake() mint.Alert
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
TransportParameters() <-chan []byte
}
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake() error
io.Closer
ChangeConnectionID(protocol.ConnectionID) error
type baseCryptoSetup interface {
HandleCryptoStream() error
ConnectionState() ConnectionState
HandleMessage([]byte, protocol.EncryptionLevel) bool
ConnectionState() tls.ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
GetOpener(protocol.EncryptionLevel) (Opener, error)
}
// ConnectionState records basic details about the QUIC connection.

View File

@@ -1,3 +0,0 @@
package handshake
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"

View File

@@ -0,0 +1,132 @@
package handshake
import (
"crypto/tls"
"net"
"time"
"unsafe"
"github.com/marten-seemann/qtls"
)
type conn struct {
remoteAddr net.Addr
}
func newConn(remote net.Addr) net.Conn {
return &conn{remoteAddr: remote}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return nil }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
type clientSessionCache struct {
tls.ClientSessionCache
}
var _ qtls.ClientSessionCache = &clientSessionCache{}
func (c *clientSessionCache) Get(sessionKey string) (*qtls.ClientSessionState, bool) {
sess, ok := c.ClientSessionCache.Get(sessionKey)
if sess == nil {
return nil, ok
}
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
usess := (*[unsafe.Sizeof(*sess)]byte)(unsafe.Pointer(sess))[:]
var session qtls.ClientSessionState
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
copy(usession, usess)
return &session, ok
}
func (c *clientSessionCache) Put(sessionKey string, cs *qtls.ClientSessionState) {
// qtls.ClientSessionState is identical to the tls.ClientSessionState.
// In order to allow users of quic-go to use a tls.Config,
// we need this workaround to use the ClientSessionCache.
// In unsafe.go we check that the two structs are actually identical.
usess := (*[unsafe.Sizeof(*cs)]byte)(unsafe.Pointer(cs))[:]
var session tls.ClientSessionState
usession := (*[unsafe.Sizeof(session)]byte)(unsafe.Pointer(&session))[:]
copy(usession, usess)
c.ClientSessionCache.Put(sessionKey, &session)
}
func tlsConfigToQtlsConfig(
c *tls.Config,
recordLayer qtls.RecordLayer,
extHandler tlsExtensionHandler,
) *qtls.Config {
if c == nil {
c = &tls.Config{}
}
// Clone the config first. This executes the tls.Config.serverInit().
// This sets the SessionTicketKey, if the user didn't supply one.
c = c.Clone()
// QUIC requires TLS 1.3 or newer
minVersion := c.MinVersion
if minVersion < qtls.VersionTLS13 {
minVersion = qtls.VersionTLS13
}
maxVersion := c.MaxVersion
if maxVersion < qtls.VersionTLS13 {
maxVersion = qtls.VersionTLS13
}
var getConfigForClient func(ch *tls.ClientHelloInfo) (*qtls.Config, error)
if c.GetConfigForClient != nil {
getConfigForClient = func(ch *tls.ClientHelloInfo) (*qtls.Config, error) {
tlsConf, err := c.GetConfigForClient(ch)
if err != nil {
return nil, err
}
if tlsConf == nil {
return nil, nil
}
return tlsConfigToQtlsConfig(tlsConf, recordLayer, extHandler), nil
}
}
var csc qtls.ClientSessionCache
if c.ClientSessionCache != nil {
csc = &clientSessionCache{c.ClientSessionCache}
}
return &qtls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
GetClientCertificate: c.GetClientCertificate,
GetConfigForClient: getConfigForClient,
VerifyPeerCertificate: c.VerifyPeerCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: csc,
MinVersion: minVersion,
MaxVersion: maxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
// no need to copy Renegotiation, it's not supported by TLS 1.3
KeyLogWriter: c.KeyLogWriter,
AlternativeRecordLayer: recordLayer,
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
}
}

View File

@@ -1,73 +0,0 @@
package handshake
import (
"bytes"
"crypto/rand"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
cookieGenerator *CookieGenerator
}
// NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16)
_, err := rand.Read(id)
if err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
cookieGenerator: cookieGenerator,
}, nil
}
// Get the server config binary representation
func (s *ServerConfig) Get() []byte {
var serverConfig bytes.Buffer
msg := HandshakeMessage{
Tag: TagSCFG,
Data: map[Tag][]byte{
TagSCID: s.ID,
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
msg.Write(&serverConfig)
return serverConfig.Bytes()
}
// Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.certChain.SignServerProof(sni, chlo, s.Get())
}
// GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
}

View File

@@ -1,184 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
message, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if message.Tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(message.Data)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
c255Foundat := -1
for i := 0; i < len(kexs)/4; i++ {
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
c255Foundat = i
break
}
}
if c255Foundat < 0 {
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
var pubsKexs []struct {
Length uint32
Value []byte
}
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
}
if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}

View File

@@ -1,93 +0,0 @@
package handshake
// A Tag in the QUIC crypto
type Tag uint32
const (
// TagCHLO is a client hello
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagREJ is a server hello rejection
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
// TagSCFG is a server config
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
// TagPAD is padding
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
// TagSNI is the server name indication
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
// TagVER is the QUIC version
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
// TagCCS are the hashes of the common certificate sets
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
// TagCCRT are the hashes of the cached certificates
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
// TagMSPC is max streams per connection
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
// TagMIDS is max incoming dyanamic streams
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
// TagUAID is the user agent ID
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
// TagSVID is the server ID (unofficial tag by us :)
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
// TagTCID is truncation of the connection ID
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagPDMD is the proof demand
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
// TagSRBF is the socket receive buffer
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
// TagICSL is the idle connection state lifetime
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
// TagNONP is the client proof nonce
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
// TagSCLS is the silently close timeout
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
// TagCOPT are the connection options
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
// TagCFCW is the initial session/connection flow control receive window
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
// TagSNO is the server nonce
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
// TagPROF is the server proof
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
// TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagKEXS is the list of key exchange algos
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
// TagAEAD is the list of AEAD algos
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
// TagPUBS is the public value for the KEX
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
// TagOBIT is the client orbit
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
// TagEXPY is the server config expiry
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
// TagCERT is the CERT data
TagCERT Tag = 0xff545243
// TagSHLO is the server hello
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagPRST is the public reset tag
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
// TagRSEQ is the public reset rejected packet number
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
// TagRNON is the public reset nonce
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
)

View File

@@ -1,123 +0,0 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type transportParameterID uint16
const quicTLSExtensionType = 0xff5
const (
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxBidiStreamsParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxUniStreamsParameterID transportParameterID = 0x8
disableMigrationParameterID transportParameterID = 0x9
)
type clientHelloTransportParameters struct {
InitialVersion protocol.VersionNumber
Parameters TransportParameters
}
func (p *clientHelloTransportParameters) Marshal() []byte {
const lenOffset = 4
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion))
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
if len(data) < 6 {
return errors.New("transport parameter data too short")
}
p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
paramsLen := int(binary.BigEndian.Uint16(data[4:6]))
data = data[6:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type encryptedExtensionsTransportParameters struct {
NegotiatedVersion protocol.VersionNumber
SupportedVersions []protocol.VersionNumber
Parameters TransportParameters
}
func (p *encryptedExtensionsTransportParameters) Marshal() []byte {
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion))
b.WriteByte(uint8(4 * len(p.SupportedVersions)))
for _, v := range p.SupportedVersions {
utils.BigEndian.WriteUint32(b, uint32(v))
}
lenOffset := b.Len()
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
if len(data) < 5 {
return errors.New("transport parameter data too short")
}
p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
numVersions := int(data[4])
if numVersions%4 != 0 {
return fmt.Errorf("invalid length for version list: %d", numVersions)
}
numVersions /= 4
data = data[5:]
if len(data) < 4*numVersions+2 /*length field for the parameter list */ {
return errors.New("transport parameter data too short")
}
p.SupportedVersions = make([]protocol.VersionNumber, numVersions)
for i := 0; i < numVersions; i++ {
p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
data = data[4:]
}
paramsLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type tlsExtensionBody struct {
data []byte
}
var _ mint.ExtensionBody = &tlsExtensionBody{}
func (e *tlsExtensionBody) Type() mint.ExtensionType {
return quicTLSExtensionType
}
func (e *tlsExtensionBody) Marshal() ([]byte, error) {
return e.data, nil
}
func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) {
e.data = data
return len(data), nil
}

View File

@@ -0,0 +1,58 @@
package handshake
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/marten-seemann/qtls"
)
const quicTLSExtensionType = 0xffa5
type extensionHandler struct {
ourParams []byte
paramsChan chan []byte
perspective protocol.Perspective
}
var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective) tlsExtensionHandler {
return &extensionHandler{
ourParams: params,
paramsChan: make(chan []byte),
perspective: pers,
}
}
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
return nil
}
return []qtls.Extension{{
Type: quicTLSExtensionType,
Data: h.ourParams,
}}
}
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
return
}
var data []byte
for _, ext := range exts {
if ext.Type == quicTLSExtensionType {
data = ext.Data
break
}
}
h.paramsChan <- data
}
func (h *extensionHandler) TransportParameters() <-chan []byte {
return h.paramsChan
}

View File

@@ -1,112 +0,0 @@
package handshake
import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type extensionHandlerClient struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber
version protocol.VersionNumber
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient(
params *TransportParameters,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
// The client reads the transport parameters from the Encrypted Extensions message.
// The paramsChan is used in the session's run loop's select statement.
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
paramsChan := make(chan TransportParameters)
return &extensionHandlerClient{
ourParams: params,
paramsChan: paramsChan,
initialVersion: initialVersion,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}
func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeClientHello {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
chtp := &clientHelloTransportParameters{
InitialVersion: h.initialVersion,
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
}
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
return nil
}
// hType == mint.HandshakeTypeEncryptedExtensions
if !found {
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
}
eetp := &encryptedExtensionsTransportParameters{}
if err := eetp.Unmarshal(ext.data); err != nil {
return err
}
// check that the negotiated_version is the current version
if eetp.NegotiatedVersion != h.version {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
}
// check that the current version is included in the supported versions
if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions")
}
// if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server
if h.version != h.initialVersion {
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions)
if !ok || h.version != negotiatedVersion {
return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version")
}
}
// check that the server sent a stateless reset token
if len(eetp.Parameters.StatelessResetToken) == 0 {
return errors.New("server didn't sent stateless_reset_token")
}
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
h.paramsChan <- eetp.Parameters
return nil
}
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@@ -1,100 +0,0 @@
package handshake
import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type extensionHandlerServer struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer(
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
// Processing the ClientHello is performed statelessly (and from a single go-routine).
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerServer{
ourParams: params,
paramsChan: paramsChan,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}
func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeEncryptedExtensions {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
eetp := &encryptedExtensionsTransportParameters{
NegotiatedVersion: h.version,
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
}
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
return nil
}
if !found {
return errors.New("ClientHello didn't contain a QUIC extension")
}
chtp := &clientHelloTransportParameters{}
if err := chtp.Unmarshal(ext.data); err != nil {
return err
}
// perform the stateless version negotiation validation:
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
// this is the case if and only if the initial version is not contained in the supported versions
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
}
// check that the client didn't send a stateless reset token
if len(chtp.Parameters.StatelessResetToken) != 0 {
// TODO: return the correct error type
return errors.New("client sent a stateless reset token")
}
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
h.paramsChan <- chtp.Parameters
return nil
}
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@@ -3,158 +3,123 @@ package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sort"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
type transportParameterID uint16
const (
originalConnectionIDParameterID transportParameterID = 0x0
idleTimeoutParameterID transportParameterID = 0x1
statelessResetTokenParameterID transportParameterID = 0x2
maxPacketSizeParameterID transportParameterID = 0x3
initialMaxDataParameterID transportParameterID = 0x4
initialMaxStreamDataBidiLocalParameterID transportParameterID = 0x5
initialMaxStreamDataBidiRemoteParameterID transportParameterID = 0x6
initialMaxStreamDataUniParameterID transportParameterID = 0x7
initialMaxStreamsBidiParameterID transportParameterID = 0x8
initialMaxStreamsUniParameterID transportParameterID = 0x9
ackDelayExponentParameterID transportParameterID = 0xa
disableMigrationParameterID transportParameterID = 0xc
)
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount
ConnectionFlowControlWindow protocol.ByteCount
InitialMaxStreamDataBidiLocal protocol.ByteCount
InitialMaxStreamDataBidiRemote protocol.ByteCount
InitialMaxStreamDataUni protocol.ByteCount
InitialMaxData protocol.ByteCount
AckDelayExponent uint8
MaxPacketSize protocol.ByteCount
MaxUniStreams uint16 // only used for IETF QUIC
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC
MaxUniStreams uint64
MaxBidiStreams uint64
OmitConnectionID bool // only used for gQUIC
IdleTimeout time.Duration
DisableMigration bool // only used for IETF QUIC
StatelessResetToken []byte // only used for IETF QUIC
IdleTimeout time.Duration
DisableMigration bool
StatelessResetToken *[16]byte
OriginalConnectionID protocol.ConnectionID
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
params := &TransportParameters{}
if value, ok := tags[TagTCID]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.OmitConnectionID = (v == 0)
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if len(data) < 2 {
return errors.New("transport parameter data too short")
}
if value, ok := tags[TagMIDS]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.MaxStreams = v
length := binary.BigEndian.Uint16(data[:2])
if len(data)-2 < int(length) {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", length, len(data)-2)
}
if value, ok := tags[TagICSL]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
}
if value, ok := tags[TagSFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.StreamFlowControlWindow = protocol.ByteCount(v)
}
if value, ok := tags[TagCFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
}
return params, nil
}
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
tags := map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}
if p.OmitConnectionID {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
return tags
}
func (p *TransportParameters) unmarshal(data []byte) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
for len(data) >= 4 {
paramID := transportParameterID(binary.BigEndian.Uint16(data[:2]))
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
}
var readAckDelayExponent bool
r := bytes.NewReader(data[2:])
for r.Len() >= 4 {
paramIDInt, _ := utils.BigEndian.ReadUint16(r)
paramID := transportParameterID(paramIDInt)
paramLen, _ := utils.BigEndian.ReadUint16(r)
parameterIDs = append(parameterIDs, paramID)
switch paramID {
case initialMaxStreamDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
case ackDelayExponentParameterID:
readAckDelayExponent = true
fallthrough
case initialMaxStreamDataBidiLocalParameterID,
initialMaxStreamDataBidiRemoteParameterID,
initialMaxStreamDataUniParameterID,
initialMaxDataParameterID,
initialMaxStreamsBidiParameterID,
initialMaxStreamsUniParameterID,
idleTimeoutParameterID,
maxPacketSizeParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
return err
}
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
default:
if r.Len() < int(paramLen) {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
}
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxBidiStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
switch paramID {
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a stateless_reset_token")
}
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token [16]byte
r.Read(token[:])
p.StatelessResetToken = &token
case originalConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_connection_id")
}
p.OriginalConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
default:
r.Seek(int64(paramLen), io.SeekCurrent)
}
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
case initialMaxUniStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
}
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
case idleTimeoutParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
}
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
case maxPacketSizeParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
}
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
if maxPacketSize < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
}
p.MaxPacketSize = maxPacketSize
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
p.StatelessResetToken = data[:16]
}
data = data[paramLen:]
}
if !readAckDelayExponent {
p.AckDelayExponent = protocol.DefaultAckDelayExponent
}
// check that every transport parameter was sent at most once
@@ -165,51 +130,131 @@ func (p *TransportParameters) unmarshal(data []byte) error {
}
}
if len(data) != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
if r.Len() != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", r.Len())
}
return nil
}
func (p *TransportParameters) marshal(b *bytes.Buffer) {
// initial_max_stream_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := utils.ReadVarInt(r)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if remainingLen-r.Len() != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for %d", paramID)
}
switch paramID {
case initialMaxStreamDataBidiLocalParameterID:
p.InitialMaxStreamDataBidiLocal = protocol.ByteCount(val)
case initialMaxStreamDataBidiRemoteParameterID:
p.InitialMaxStreamDataBidiRemote = protocol.ByteCount(val)
case initialMaxStreamDataUniParameterID:
p.InitialMaxStreamDataUni = protocol.ByteCount(val)
case initialMaxDataParameterID:
p.InitialMaxData = protocol.ByteCount(val)
case initialMaxStreamsBidiParameterID:
p.MaxBidiStreams = val
case initialMaxStreamsUniParameterID:
p.MaxUniStreams = val
case idleTimeoutParameterID:
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
case maxPacketSizeParameterID:
if val < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
}
p.MaxPacketSize = protocol.ByteCount(val)
case ackDelayExponentParameterID:
if val > protocol.MaxAckDelayExponent {
return fmt.Errorf("invalid value for ack_delay_exponent: %d (maximum %d)", val, protocol.MaxAckDelayExponent)
}
p.AckDelayExponent = uint8(val)
default:
return fmt.Errorf("TransportParameter BUG: transport parameter %d not found", paramID)
}
return nil
}
// Marshal the transport parameters
func (p *TransportParameters) Marshal() []byte {
b := &bytes.Buffer{}
b.Write([]byte{0, 0}) // length. Will be replaced later
// initial_max_stream_data_bidi_local
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiLocalParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiLocal))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiLocal))
// initial_max_stream_data_bidi_remote
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataBidiRemoteParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataBidiRemote))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataBidiRemote))
// initial_max_stream_data_uni
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataUniParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxStreamDataUni))))
utils.WriteVarInt(b, uint64(p.InitialMaxStreamDataUni))
// initial_max_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.InitialMaxData))))
utils.WriteVarInt(b, uint64(p.InitialMaxData))
// initial_max_bidi_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsBidiParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxBidiStreams)))
utils.WriteVarInt(b, p.MaxBidiStreams)
// initial_max_uni_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamsUniParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(p.MaxUniStreams)))
utils.WriteVarInt(b, p.MaxUniStreams)
// idle_timeout
idleTimeout := uint64(p.IdleTimeout / time.Millisecond)
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(idleTimeout)))
utils.WriteVarInt(b, idleTimeout)
// max_packet_size
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(protocol.MaxReceivePacketSize))))
utils.WriteVarInt(b, uint64(protocol.MaxReceivePacketSize))
// ack_delay_exponent
// Only send it if is different from the default value.
if p.AckDelayExponent != protocol.DefaultAckDelayExponent {
utils.BigEndian.WriteUint16(b, uint16(ackDelayExponentParameterID))
utils.BigEndian.WriteUint16(b, uint16(utils.VarIntLen(uint64(p.AckDelayExponent))))
utils.WriteVarInt(b, uint64(p.AckDelayExponent))
}
// disable_migration
if p.DisableMigration {
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
utils.BigEndian.WriteUint16(b, 0)
}
if len(p.StatelessResetToken) > 0 {
if p.StatelessResetToken != nil {
utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID))
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
b.Write(p.StatelessResetToken)
utils.BigEndian.WriteUint16(b, 16)
b.Write(p.StatelessResetToken[:])
}
// original_connection_id
if p.OriginalConnectionID.Len() > 0 {
utils.BigEndian.WriteUint16(b, uint16(originalConnectionIDParameterID))
utils.BigEndian.WriteUint16(b, uint16(p.OriginalConnectionID.Len()))
b.Write(p.OriginalConnectionID.Bytes())
}
data := b.Bytes()
binary.BigEndian.PutUint16(data[:2], uint16(b.Len()-2))
return data
}
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
logString := "&handshake.TransportParameters{OriginalConnectionID: %s, InitialMaxStreamDataBidiLocal: %#x, InitialMaxStreamDataBidiRemote: %#x, InitialMaxStreamDataUni: %#x, InitialMaxData: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s, AckDelayExponent: %d"
logParams := []interface{}{p.OriginalConnectionID, p.InitialMaxStreamDataBidiLocal, p.InitialMaxStreamDataBidiRemote, p.InitialMaxStreamDataUni, p.InitialMaxData, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout, p.AckDelayExponent}
if p.StatelessResetToken != nil { // the client never sends a stateless reset token
logString += ", StatelessResetToken: %#x"
logParams = append(logParams, *p.StatelessResetToken)
}
logString += "}"
return fmt.Sprintf(logString, logParams...)
}

View File

@@ -0,0 +1,38 @@
package handshake
// This package uses unsafe to convert between:
// * qtls.ConnectionState and tls.ConnectionState
// * qtls.ClientSessionState and tls.ClientSessionState
// We check in init() that this conversion actually is safe.
import (
"crypto/tls"
"reflect"
"github.com/marten-seemann/qtls"
)
func init() {
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
}
if !structsEqual(&tls.ClientSessionState{}, &qtls.ClientSessionState{}) {
panic("qtls.ClientSessionState not compatible with tls.ClientSessionState")
}
}
func structsEqual(a, b interface{}) bool {
sa := reflect.ValueOf(a).Elem()
sb := reflect.ValueOf(b).Elem()
if sa.NumField() != sb.NumField() {
return false
}
for i := 0; i < sa.NumField(); i++ {
fa := sa.Type().Field(i)
fb := sb.Type().Field(i)
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
return false
}
}
return true
}