use go 1.12

Signed-off-by: hongming <talonwan@yunify.com>
This commit is contained in:
hongming
2019-03-12 15:47:56 +08:00
parent b59c244ca2
commit 4144404b0b
1110 changed files with 161100 additions and 14519 deletions

View File

@@ -0,0 +1,5 @@
root = true
[*]
indent_style = tab
indent_size = 2

3
vendor/github.com/lucas-clemente/quic-go/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,3 @@
debug
debug.test
main

View File

@@ -0,0 +1,24 @@
{
"DisableAll": true,
"Exclude": [
"vendor",
"streams_map_incoming_generic.go",
"streams_map_outgoing_generic.go"
],
"Enable": [
"deadcode",
"goimports",
"ineffassign",
"megacheck",
"misspell",
"structcheck",
"unconvert",
"varcheck",
"vet"
],
"Linters": {
"vet": "go tool vet -printfuncs=Infof,Debugf,Warningf,Errorf:PATH:LINE:MESSAGE",
"misspell": "misspell -i ect:PATH:LINE:COL:MESSAGE",
"megacheck": "megacheck -ignore github.com/lucas-clemente/quic-go/h2quic/response_writer_closenotifier.go:SA1019:PATH:LINE:COL:MESSAGE"
}
}

52
vendor/github.com/lucas-clemente/quic-go/.travis.yml generated vendored Normal file
View File

@@ -0,0 +1,52 @@
dist: trusty
group: travis_latest
addons:
hosts:
- quic.clemente.io
language: go
go:
- "1.10.4"
- "1.11"
# first part of the GOARCH workaround
# setting the GOARCH directly doesn't work, since the value will be overwritten later
# so set it to a temporary environment variable first
env:
global:
- TIMESCALE_FACTOR=20
matrix:
- TRAVIS_GOARCH=amd64 TESTMODE=lint
- TRAVIS_GOARCH=amd64 TESTMODE=unit
- TRAVIS_GOARCH=amd64 TESTMODE=integration
- TRAVIS_GOARCH=386 TESTMODE=unit
- TRAVIS_GOARCH=386 TESTMODE=integration
# Linters might work differently in different Go versions.
# Only run them in the most recent one.
matrix:
exclude:
- go: "1.10.4"
env: TRAVIS_GOARCH=amd64 TESTMODE=lint
- go: "1.10.4"
env: TRAVIS_GOARCH=386 TESTMODE=lint
# second part of the GOARCH workaround
# now actually set the GOARCH env variable to the value of the temporary variable set earlier
before_install:
- go get golang.org/x/tools/cmd/cover
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/onsi/gomega
- export GOARCH=$TRAVIS_GOARCH
- go env # for debugging
- "printf \"quic.clemente.io certificate valid until: \" && openssl x509 -in example/fullchain.pem -enddate -noout | cut -d = -f 2"
- "export DISPLAY=:99.0"
- "Xvfb $DISPLAY &> /dev/null &"
script:
- .travis/script.sh
after_success:
- .travis/after_success.sh

44
vendor/github.com/lucas-clemente/quic-go/Changelog.md generated vendored Normal file
View File

@@ -0,0 +1,44 @@
# Changelog
## v0.10.0 (2018-08-28)
- Add support for QUIC 44, drop support for QUIC 42.
## v0.9.0 (2018-08-15)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
- Split Session.Close into one method for regular closing and one for closing with an error.
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43.
- Add dial functions that use a context.
- Multiplex clients on a net.PacketConn, when using Dial(conn).
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows
- Add a `quic.Config` option for QUIC versions
- Add a `quic.Config` option to request omission of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Add a `quic.Config` option to configure the idle timeout
- Add a `quic.Config` option to configure keep-alive
- Rename the STK to Cookie
- Implement `net.Conn`-style deadlines for streams
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/lucas-clemente/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/lucas-clemente/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
- Drop support for Go 1.7 and 1.8.
- Various bugfixes

21
vendor/github.com/lucas-clemente/quic-go/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

73
vendor/github.com/lucas-clemente/quic-go/README.md generated vendored Normal file
View File

@@ -0,0 +1,73 @@
# A QUIC implementation in pure Go
<img src="docs/quic.png" width=303 height=124>
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/lucas-clemente/quic-go)
[![Travis Build Status](https://img.shields.io/travis/lucas-clemente/quic-go/master.svg?style=flat-square&label=Travis+build)](https://travis-ci.org/lucas-clemente/quic-go)
[![CircleCI Build Status](https://img.shields.io/circleci/project/github/lucas-clemente/quic-go.svg?style=flat-square&label=CircleCI+build)](https://circleci.com/gh/lucas-clemente/quic-go)
[![Windows Build Status](https://img.shields.io/appveyor/ci/lucas-clemente/quic-go/master.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/lucas-clemente/quic-go/branch/master)
[![Code Coverage](https://img.shields.io/codecov/c/github/lucas-clemente/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/lucas-clemente/quic-go/)
quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol in Go.
## Roadmap
quic-go is compatible with the current version(s) of Google Chrome and QUIC as deployed on Google's servers. We're actively tracking the development of the Chrome code to ensure compatibility as the protocol evolves. In that process, we're dropping support for old QUIC versions.
As Google's QUIC versions are expected to converge towards the [IETF QUIC draft](https://github.com/quicwg/base-drafts), quic-go will eventually implement that draft.
## Guides
We currently support Go 1.9+.
Installing and updating dependencies:
go get -t -u ./...
Running tests:
go test ./...
### Running the example server
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
### QUIC without HTTP/2
Take a look at [this echo example](example/echo/echo.go).
### Using the example client
go run example/client/main.go https://clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
```
### As a client
See the [example client](example/client/main.go). Use a `h2quic.RoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &h2quic.RoundTripper{},
}
```
## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors, they are tagged with [help wanted](https://github.com/lucas-clemente/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.

37
vendor/github.com/lucas-clemente/quic-go/appveyor.yml generated vendored Normal file
View File

@@ -0,0 +1,37 @@
version: "{build}"
os: Windows Server 2012 R2
environment:
GOPATH: c:\gopath
CGO_ENABLED: 0
TIMESCALE_FACTOR: 20
matrix:
- GOARCH: 386
- GOARCH: amd64
hosts:
quic.clemente.io: 127.0.0.1
clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:
- rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.11.windows-amd64.zip
- 7z x go1.11.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH%
- echo %GOPATH%
- go get github.com/onsi/ginkgo/ginkgo
- go get github.com/onsi/gomega
- go version
- go env
- go get -v -t ./...
build_script:
- ginkgo -r -v -randomizeAllSpecs -randomizeSuites -trace -skipPackage benchmark,integrationtests
- ginkgo -randomizeAllSpecs -randomizeSuites -trace benchmark -- -samples=1
test: off
deploy: off

View File

@@ -0,0 +1,27 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var bufferPool sync.Pool
func getPacketBuffer() *[]byte {
return bufferPool.Get().(*[]byte)
}
func putPacketBuffer(buf *[]byte) {
if cap(*buf) != int(protocol.MaxReceivePacketSize) {
panic("putPacketBuffer called with packet of wrong size!")
}
bufferPool.Put(buf)
}
func init() {
bufferPool.New = func() interface{} {
b := make([]byte, 0, protocol.MaxReceivePacketSize)
return &b
}
}

587
vendor/github.com/lucas-clemente/quic-go/client.go generated vendored Normal file
View File

@@ -0,0 +1,587 @@
package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
type client struct {
mutex sync.Mutex
conn connection
// If the client is created with DialAddr, we create a packet conn.
// If it is started with Dial, we take a packet conn as a parameter.
createdPacketConn bool
packetHandlers packetHandlerManager
token []byte
versionNegotiated bool // has the server accepted our version
receivedVersionNegotiationPacket bool
negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
tlsConf *tls.Config
mintConf *mint.Config
config *Config
srcConnID protocol.ConnectionID
destConnID protocol.ConnectionID
initialVersion protocol.VersionNumber
version protocol.VersionNumber
handshakeChan chan struct{}
closeCallback func(protocol.ConnectionID)
session quicSession
logger utils.Logger
}
var _ packetHandler = &client{}
var (
// make it possible to mock connection ID generation in the tests
generateConnectionID = protocol.GenerateConnectionID
generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
)
// DialAddr establishes a new QUIC connection to a server.
// The hostname for SNI is taken from the given address.
func DialAddr(
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialAddrContext(context.Background(), addr, tlsConf, config)
}
// DialAddrContext establishes a new QUIC connection to a server using the provided context.
// The hostname for SNI is taken from the given address.
func DialAddrContext(
ctx context.Context,
addr string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// The host parameter is used for SNI.
func Dial(
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
}
// DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
// The host parameter is used for SNI.
func DialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
) (Session, error) {
return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
}
func dialContext(
ctx context.Context,
pconn net.PacketConn,
remoteAddr net.Addr,
host string,
tlsConf *tls.Config,
config *Config,
createdPacketConn bool,
) (Session, error) {
config = populateClientConfig(config, createdPacketConn)
if !createdPacketConn {
for _, v := range config.Versions {
if v == protocol.Version44 {
return nil, errors.New("Cannot multiplex connections using gQUIC 44, see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/pE9NlLLjizE. Please disable gQUIC 44 in the quic.Config, or use DialAddr")
}
}
}
packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
if err != nil {
return nil, err
}
c, err := newClient(pconn, remoteAddr, config, tlsConf, host, packetHandlers.Remove, createdPacketConn)
if err != nil {
return nil, err
}
c.packetHandlers = packetHandlers
if err := c.dial(ctx); err != nil {
return nil, err
}
return c.session, nil
}
func newClient(
pconn net.PacketConn,
remoteAddr net.Addr,
config *Config,
tlsConf *tls.Config,
host string,
closeCallback func(protocol.ConnectionID),
createdPacketConn bool,
) (*client, error) {
if tlsConf == nil {
tlsConf = &tls.Config{}
}
if tlsConf.ServerName == "" {
var err error
tlsConf.ServerName, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
}
}
// check that all versions are actually supported
if config != nil {
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
return nil, fmt.Errorf("%s is not a valid QUIC version", v)
}
}
}
onClose := func(protocol.ConnectionID) {}
if closeCallback != nil {
onClose = closeCallback
}
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
createdPacketConn: createdPacketConn,
tlsConf: tlsConf,
config: config,
version: config.Versions[0],
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}
return c, c.generateConnectionIDs()
}
// populateClientConfig populates fields in the quic.Config with their default values, if none are set
// it may be called with nil
func populateClientConfig(config *Config, createdPacketConn bool) *Config {
if config == nil {
config = &Config{}
}
versions := config.Versions
if len(versions) == 0 {
versions = protocol.SupportedVersions
}
handshakeTimeout := protocol.DefaultHandshakeTimeout
if config.HandshakeTimeout != 0 {
handshakeTimeout = config.HandshakeTimeout
}
idleTimeout := protocol.DefaultIdleTimeout
if config.IdleTimeout != 0 {
idleTimeout = config.IdleTimeout
}
maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
if maxReceiveStreamFlowControlWindow == 0 {
maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowClient
}
maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
if maxReceiveConnectionFlowControlWindow == 0 {
maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowClient
}
maxIncomingStreams := config.MaxIncomingStreams
if maxIncomingStreams == 0 {
maxIncomingStreams = protocol.DefaultMaxIncomingStreams
} else if maxIncomingStreams < 0 {
maxIncomingStreams = 0
}
maxIncomingUniStreams := config.MaxIncomingUniStreams
if maxIncomingUniStreams == 0 {
maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
connIDLen := config.ConnectionIDLength
if connIDLen == 0 && !createdPacketConn {
connIDLen = protocol.DefaultConnectionIDLength
}
for _, v := range versions {
if v == protocol.Version44 {
connIDLen = 0
}
}
return &Config{
Versions: versions,
HandshakeTimeout: handshakeTimeout,
IdleTimeout: idleTimeout,
RequestConnectionIDOmission: config.RequestConnectionIDOmission,
ConnectionIDLength: connIDLen,
MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
MaxIncomingStreams: maxIncomingStreams,
MaxIncomingUniStreams: maxIncomingUniStreams,
KeepAlive: config.KeepAlive,
}
}
func (c *client) generateConnectionIDs() error {
connIDLen := protocol.ConnectionIDLenGQUIC
if c.version.UsesTLS() {
connIDLen = c.config.ConnectionIDLength
}
srcConnID, err := generateConnectionID(connIDLen)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = generateConnectionIDForInitial()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
if c.version == protocol.Version44 {
c.srcConnID = nil
}
return nil
}
func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
var err error
if c.version.UsesTLS() {
err = c.dialTLS(ctx)
} else {
err = c.dialGQUIC(ctx)
}
return err
}
func (c *client) dialGQUIC(ctx context.Context) error {
if err := c.createNewGQUICSession(); err != nil {
return err
}
err := c.establishSecureConnection(ctx)
if err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
func (c *client) dialTLS(ctx context.Context) error {
params := &handshake.TransportParameters{
StreamFlowControlWindow: protocol.ReceiveStreamFlowControlWindow,
ConnectionFlowControlWindow: protocol.ReceiveConnectionFlowControlWindow,
IdleTimeout: c.config.IdleTimeout,
OmitConnectionID: c.config.RequestConnectionIDOmission,
MaxBidiStreams: uint16(c.config.MaxIncomingStreams),
MaxUniStreams: uint16(c.config.MaxIncomingUniStreams),
DisableMigration: true,
}
extHandler := handshake.NewExtensionHandlerClient(params, c.initialVersion, c.config.Versions, c.version, c.logger)
mintConf, err := tlsToMintConfig(c.tlsConf, protocol.PerspectiveClient)
if err != nil {
return err
}
mintConf.ExtensionHandler = extHandler
c.mintConf = mintConf
if err := c.createNewTLSSession(extHandler.GetPeerParams(), c.version); err != nil {
return err
}
err = c.establishSecureConnection(ctx)
if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
return c.dial(ctx)
}
return err
}
// establishSecureConnection runs the session, and tries to establish a secure connection
// It returns:
// - errCloseSessionForNewVersion when the server sends a version negotiation packet
// - handshake.ErrCloseSessionForRetry when the server performs a stateless retry (for IETF QUIC)
// - any other error that might occur
// - when the connection is secure (for gQUIC), or forward-secure (for IETF QUIC)
func (c *client) establishSecureConnection(ctx context.Context) error {
errorChan := make(chan error, 1)
go func() {
err := c.session.run() // returns as soon as the session is closed
if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
c.conn.Close()
}
errorChan <- err
}()
select {
case <-ctx.Done():
// The session will send a PeerGoingAway error to the server.
c.session.Close()
return ctx.Err()
case err := <-errorChan:
return err
case <-c.handshakeChan:
// handshake successfully completed
return nil
}
}
func (c *client) handlePacket(p *receivedPacket) {
if err := c.handlePacketImpl(p); err != nil {
c.logger.Errorf("error handling packet: %s", err)
}
}
func (c *client) handlePacketImpl(p *receivedPacket) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// handle Version Negotiation Packets
if p.header.IsVersionNegotiation {
err := c.handleVersionNegotiationPacket(p.header)
if err != nil {
c.session.destroy(err)
}
// version negotiation packets have no payload
return err
}
if !c.version.UsesIETFHeaderFormat() {
connID := p.header.DestConnectionID
// reject packets with truncated connection id if we didn't request truncation
if !c.config.RequestConnectionIDOmission && connID.Len() == 0 {
return errors.New("received packet with truncated connection ID, but didn't request truncation")
}
// reject packets with the wrong connection ID
if connID.Len() > 0 && !connID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", connID, c.srcConnID)
}
if p.header.ResetFlag {
return c.handlePublicReset(p)
}
} else {
// reject packets with the wrong connection ID
if !p.header.DestConnectionID.Equal(c.srcConnID) {
return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
}
}
if p.header.IsLongHeader {
switch p.header.Type {
case protocol.PacketTypeRetry:
c.handleRetryPacket(p.header)
return nil
case protocol.PacketTypeHandshake, protocol.PacketType0RTT:
default:
return fmt.Errorf("Received unsupported packet type: %s", p.header.Type)
}
}
// this is the first packet we are receiving
// since it is not a Version Negotiation Packet, this means the server supports the suggested version
if !c.versionNegotiated {
c.versionNegotiated = true
}
c.session.handlePacket(p)
return nil
}
func (c *client) handlePublicReset(p *receivedPacket) error {
cr := c.conn.RemoteAddr()
// check if the remote address and the connection ID match
// otherwise this might be an attacker trying to inject a PUBLIC_RESET to kill the connection
if cr.Network() != p.remoteAddr.Network() || cr.String() != p.remoteAddr.String() || !p.header.DestConnectionID.Equal(c.srcConnID) {
return errors.New("Received a spoofed Public Reset")
}
pr, err := wire.ParsePublicReset(bytes.NewReader(p.data))
if err != nil {
return fmt.Errorf("Received a Public Reset. An error occurred parsing the packet: %s", err)
}
c.session.closeRemote(qerr.Error(qerr.PublicReset, fmt.Sprintf("Received a Public Reset for packet number %#x", pr.RejectedPacketNumber)))
c.logger.Infof("Received Public Reset, rejected packet number: %#x", pr.RejectedPacketNumber)
return nil
}
func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// ignore delayed / duplicated version negotiation packets
if c.receivedVersionNegotiationPacket || c.versionNegotiated {
c.logger.Debugf("Received a delayed Version Negotiation Packet.")
return nil
}
for _, v := range hdr.SupportedVersions {
if v == c.version {
// the version negotiation packet contains the version that we offered
// this might be a packet sent by an attacker (or by a terribly broken server implementation)
// ignore it
return nil
}
}
c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
if !ok {
return qerr.InvalidVersion
}
c.receivedVersionNegotiationPacket = true
c.negotiatedVersions = hdr.SupportedVersions
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
if err := c.generateConnectionIDs(); err != nil {
return err
}
c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.destroy(errCloseSessionForNewVersion)
return nil
}
func (c *client) handleRetryPacket(hdr *wire.Header) {
c.logger.Debugf("<- Received Retry")
hdr.Log(c.logger)
if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
return
}
if hdr.SrcConnectionID.Equal(c.destConnID) {
c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return
}
// If a token is already set, this means that we already received a Retry from the server.
// Ignore this Retry packet.
if len(c.token) > 0 {
c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
return
}
c.destConnID = hdr.SrcConnectionID
c.token = hdr.Token
c.session.destroy(errCloseSessionForRetry)
}
func (c *client) createNewGQUICSession() error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newClientSession(
c.conn,
runner,
c.version,
c.destConnID,
c.srcConnID,
c.tlsConf,
c.config,
c.initialVersion,
c.negotiatedVersions,
c.logger,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
if c.config.RequestConnectionIDOmission {
c.packetHandlers.Add(protocol.ConnectionID{}, c)
}
return nil
}
func (c *client) createNewTLSSession(
paramsChan <-chan handshake.TransportParameters,
version protocol.VersionNumber,
) error {
c.mutex.Lock()
defer c.mutex.Unlock()
runner := &runner{
onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
removeConnectionIDImpl: c.closeCallback,
}
sess, err := newTLSClientSession(
c.conn,
runner,
c.token,
c.destConnID,
c.srcConnID,
c.config,
c.mintConf,
paramsChan,
1,
c.logger,
c.version,
)
if err != nil {
return err
}
c.session = sess
c.packetHandlers.Add(c.srcConnID, c)
return nil
}
func (c *client) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return nil
}
return c.session.Close()
}
func (c *client) destroy(e error) {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.session == nil {
return
}
c.session.destroy(e)
}
func (c *client) GetVersion() protocol.VersionNumber {
c.mutex.Lock()
v := c.version
c.mutex.Unlock()
return v
}
func (c *client) GetPerspective() protocol.Perspective {
return protocol.PerspectiveClient
}

18
vendor/github.com/lucas-clemente/quic-go/codecov.yml generated vendored Normal file
View File

@@ -0,0 +1,18 @@
coverage:
round: nearest
ignore:
- streams_map_incoming_bidi.go
- streams_map_incoming_uni.go
- streams_map_outgoing_bidi.go
- streams_map_outgoing_uni.go
- h2quic/gzipreader.go
- h2quic/response.go
- internal/ackhandler/packet_linkedlist.go
- internal/utils/byteinterval_linkedlist.go
- internal/utils/packetinterval_linkedlist.go
- internal/utils/linkedlist/linkedlist.go
status:
project:
default:
threshold: 0.5
patch: false

54
vendor/github.com/lucas-clemente/quic-go/conn.go generated vendored Normal file
View File

@@ -0,0 +1,54 @@
package quic
import (
"net"
"sync"
)
type connection interface {
Write([]byte) error
Read([]byte) (int, net.Addr, error)
Close() error
LocalAddr() net.Addr
RemoteAddr() net.Addr
SetCurrentRemoteAddr(net.Addr)
}
type conn struct {
mutex sync.RWMutex
pconn net.PacketConn
currentAddr net.Addr
}
var _ connection = &conn{}
func (c *conn) Write(p []byte) error {
_, err := c.pconn.WriteTo(p, c.currentAddr)
return err
}
func (c *conn) Read(p []byte) (int, net.Addr, error) {
return c.pconn.ReadFrom(p)
}
func (c *conn) SetCurrentRemoteAddr(addr net.Addr) {
c.mutex.Lock()
c.currentAddr = addr
c.mutex.Unlock()
}
func (c *conn) LocalAddr() net.Addr {
return c.pconn.LocalAddr()
}
func (c *conn) RemoteAddr() net.Addr {
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}
func (c *conn) Close() error {
return c.pconn.Close()
}

View File

@@ -0,0 +1,42 @@
package quic
import (
"io"
"github.com/lucas-clemente/quic-go/internal/flowcontrol"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type cryptoStream interface {
StreamID() protocol.StreamID
io.Reader
io.Writer
handleStreamFrame(*wire.StreamFrame) error
hasData() bool
popStreamFrame(protocol.ByteCount) (*wire.StreamFrame, bool)
closeForShutdown(error)
setReadOffset(protocol.ByteCount)
// methods needed for flow control
getWindowUpdate() protocol.ByteCount
handleMaxStreamDataFrame(*wire.MaxStreamDataFrame)
}
type cryptoStreamImpl struct {
*stream
}
var _ cryptoStream = &cryptoStreamImpl{}
func newCryptoStream(sender streamSender, flowController flowcontrol.StreamFlowController, version protocol.VersionNumber) cryptoStream {
str := newStream(version.CryptoStreamID(), sender, flowController, version)
return &cryptoStreamImpl{str}
}
// SetReadOffset sets the read offset.
// It is only needed for the crypto stream.
// It must not be called concurrently with any other stream methods, especially Read and Write.
func (s *cryptoStreamImpl) setReadOffset(offset protocol.ByteCount) {
s.receiveStream.readOffset = offset
s.receiveStream.frameQueue.readPos = offset
}

View File

@@ -0,0 +1,158 @@
package quic
import (
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type frameSorter struct {
queue map[protocol.ByteCount][]byte
readPos protocol.ByteCount
finalOffset protocol.ByteCount
gaps *utils.ByteIntervalList
}
var errDuplicateStreamData = errors.New("Duplicate Stream Data")
func newFrameSorter() *frameSorter {
s := frameSorter{
gaps: utils.NewByteIntervalList(),
queue: make(map[protocol.ByteCount][]byte),
finalOffset: protocol.MaxByteCount,
}
s.gaps.PushFront(utils.ByteInterval{Start: 0, End: protocol.MaxByteCount})
return &s
}
func (s *frameSorter) Push(data []byte, offset protocol.ByteCount, fin bool) error {
err := s.push(data, offset, fin)
if err == errDuplicateStreamData {
return nil
}
return err
}
func (s *frameSorter) push(data []byte, offset protocol.ByteCount, fin bool) error {
if fin {
s.finalOffset = offset + protocol.ByteCount(len(data))
}
if len(data) == 0 {
return nil
}
var wasCut bool
if oldData, ok := s.queue[offset]; ok {
if len(data) <= len(oldData) {
return errDuplicateStreamData
}
data = data[len(oldData):]
offset += protocol.ByteCount(len(oldData))
wasCut = true
}
start := offset
end := offset + protocol.ByteCount(len(data))
// skip all gaps that are before this stream frame
var gap *utils.ByteIntervalElement
for gap = s.gaps.Front(); gap != nil; gap = gap.Next() {
// the frame is a duplicate. Ignore it
if end <= gap.Value.Start {
return errDuplicateStreamData
}
if end > gap.Value.Start && start <= gap.Value.End {
break
}
}
if gap == nil {
return errors.New("StreamFrameSorter BUG: no gap found")
}
if start < gap.Value.Start {
add := gap.Value.Start - start
offset += add
start += add
data = data[add:]
wasCut = true
}
// find the highest gaps whose Start lies before the end of the frame
endGap := gap
for end >= endGap.Value.End {
nextEndGap := endGap.Next()
if nextEndGap == nil {
return errors.New("StreamFrameSorter BUG: no end gap found")
}
if endGap != gap {
s.gaps.Remove(endGap)
}
if end <= nextEndGap.Value.Start {
break
}
// delete queued frames completely covered by the current frame
delete(s.queue, endGap.Value.End)
endGap = nextEndGap
}
if end > endGap.Value.End {
cutLen := end - endGap.Value.End
len := protocol.ByteCount(len(data)) - cutLen
end -= cutLen
data = data[:len]
wasCut = true
}
if start == gap.Value.Start {
if end >= gap.Value.End {
// the frame completely fills this gap
// delete the gap
s.gaps.Remove(gap)
}
if end < endGap.Value.End {
// the frame covers the beginning of the gap
// adjust the Start value to shrink the gap
endGap.Value.Start = end
}
} else if end == endGap.Value.End {
// the frame covers the end of the gap
// adjust the End value to shrink the gap
gap.Value.End = start
} else {
if gap == endGap {
// the frame lies within the current gap, splitting it into two
// insert a new gap and adjust the current one
intv := utils.ByteInterval{Start: end, End: gap.Value.End}
s.gaps.InsertAfter(intv, gap)
gap.Value.End = start
} else {
gap.Value.End = start
endGap.Value.Start = end
}
}
if s.gaps.Len() > protocol.MaxStreamFrameSorterGaps {
return errors.New("Too many gaps in received data")
}
if wasCut {
newData := make([]byte, len(data))
copy(newData, data)
data = newData
}
s.queue[offset] = data
return nil
}
func (s *frameSorter) Pop() ([]byte /* data */, bool /* fin */) {
data, ok := s.queue[s.readPos]
if !ok {
return nil, s.readPos >= s.finalOffset
}
delete(s.queue, s.readPos)
s.readPos += protocol.ByteCount(len(data))
return data, s.readPos >= s.finalOffset
}

103
vendor/github.com/lucas-clemente/quic-go/framer.go generated vendored Normal file
View File

@@ -0,0 +1,103 @@
package quic
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type framer struct {
streamGetter streamGetter
cryptoStream cryptoStream
version protocol.VersionNumber
streamQueueMutex sync.Mutex
activeStreams map[protocol.StreamID]struct{}
streamQueue []protocol.StreamID
controlFrameMutex sync.Mutex
controlFrames []wire.Frame
}
func newFramer(
cryptoStream cryptoStream,
streamGetter streamGetter,
v protocol.VersionNumber,
) *framer {
return &framer{
streamGetter: streamGetter,
cryptoStream: cryptoStream,
activeStreams: make(map[protocol.StreamID]struct{}),
version: v,
}
}
func (f *framer) QueueControlFrame(frame wire.Frame) {
f.controlFrameMutex.Lock()
f.controlFrames = append(f.controlFrames, frame)
f.controlFrameMutex.Unlock()
}
func (f *framer) AppendControlFrames(frames []wire.Frame, maxLen protocol.ByteCount) ([]wire.Frame, protocol.ByteCount) {
var length protocol.ByteCount
f.controlFrameMutex.Lock()
for len(f.controlFrames) > 0 {
frame := f.controlFrames[len(f.controlFrames)-1]
frameLen := frame.Length(f.version)
if length+frameLen > maxLen {
break
}
frames = append(frames, frame)
length += frameLen
f.controlFrames = f.controlFrames[:len(f.controlFrames)-1]
}
f.controlFrameMutex.Unlock()
return frames, length
}
// AddActiveStream adds a stream that has data to write.
// It should not be used for the crypto stream.
func (f *framer) AddActiveStream(id protocol.StreamID) {
f.streamQueueMutex.Lock()
if _, ok := f.activeStreams[id]; !ok {
f.streamQueue = append(f.streamQueue, id)
f.activeStreams[id] = struct{}{}
}
f.streamQueueMutex.Unlock()
}
func (f *framer) AppendStreamFrames(frames []wire.Frame, maxLen protocol.ByteCount) []wire.Frame {
var length protocol.ByteCount
f.streamQueueMutex.Lock()
// pop STREAM frames, until less than MinStreamFrameSize bytes are left in the packet
numActiveStreams := len(f.streamQueue)
for i := 0; i < numActiveStreams; i++ {
if maxLen-length < protocol.MinStreamFrameSize {
break
}
id := f.streamQueue[0]
f.streamQueue = f.streamQueue[1:]
// This should never return an error. Better check it anyway.
// The stream will only be in the streamQueue, if it enqueued itself there.
str, err := f.streamGetter.GetOrOpenSendStream(id)
// The stream can be nil if it completed after it said it had data.
if str == nil || err != nil {
delete(f.activeStreams, id)
continue
}
frame, hasMoreData := str.popStreamFrame(maxLen - length)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue = append(f.streamQueue, id)
} else { // no more data to send. Stream is not active any more
delete(f.activeStreams, id)
}
if frame == nil { // can happen if the receiveStream was canceled after it said it had data
continue
}
frames = append(frames, frame)
length += frame.Length(f.version)
}
f.streamQueueMutex.Unlock()
return frames
}

View File

@@ -0,0 +1,314 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type roundTripperOpts struct {
DisableCompression bool
}
var dialAddr = quic.DialAddr
// client is a HTTP2 client doing QUIC requests
type client struct {
mutex sync.RWMutex
tlsConf *tls.Config
config *quic.Config
opts *roundTripperOpts
hostname string
handshakeErr error
dialOnce sync.Once
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
session quic.Session
headerStream quic.Stream
headerErr *qerr.QuicError
headerErrored chan struct{} // this channel is closed if an error occurs on the header stream
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
logger utils.Logger
}
var _ http.RoundTripper = &client{}
var defaultQuicConfig = &quic.Config{
RequestConnectionIDOmission: true,
KeepAlive: true,
}
// newClient creates a new client
func newClient(
hostname string,
tlsConfig *tls.Config,
opts *roundTripperOpts,
quicConfig *quic.Config,
dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error),
) *client {
config := defaultQuicConfig
if quicConfig != nil {
config = quicConfig
}
return &client{
hostname: authorityAddr("https", hostname),
responses: make(map[protocol.StreamID]chan *http.Response),
tlsConf: tlsConfig,
config: config,
opts: opts,
headerErrored: make(chan struct{}),
dialer: dialer,
logger: utils.DefaultLogger.WithPrefix("client"),
}
}
// dial dials the connection
func (c *client) dial() error {
var err error
if c.dialer != nil {
c.session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config)
} else {
c.session, err = dialAddr(c.hostname, c.tlsConf, c.config)
}
if err != nil {
return err
}
// once the version has been negotiated, open the header stream
c.headerStream, err = c.session.OpenStream()
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream, c.logger)
go c.handleHeaderStream()
return nil
}
func (c *client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var err error
for err == nil {
err = c.readResponse(h2framer, decoder)
}
if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway {
c.logger.Debugf("Error handling header stream: %s", err)
}
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error())
// stop all running request
close(c.headerErrored)
}
func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error {
frame, err := h2framer.ReadFrame()
if err != nil {
return err
}
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
return errors.New("not a headers frame")
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
return fmt.Errorf("cannot read header fields: %s", err.Error())
}
c.mutex.RLock()
responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
return fmt.Errorf("response channel for stream %d not found", hframe.StreamID)
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
return err
}
responseChan <- rsp
return nil
}
// Roundtrip executes a request and returns a response
func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host)
}
c.dialOnce.Do(func() {
c.handshakeErr = c.dial()
})
if c.handshakeErr != nil {
return nil, c.handshakeErr
}
hasBody := (req.Body != nil)
responseChan := make(chan *http.Response)
dataStream, err := c.session.OpenStreamSync()
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
c.mutex.Lock()
c.responses[dataStream.StreamID()] = responseChan
c.mutex.Unlock()
var requestedGzip bool
if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip)
if err != nil {
_ = c.closeWithError(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
ctx := req.Context()
for !(bodySent && receivedResponse) {
select {
case res = <-responseChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
case <-ctx.Done():
// error code 6 signals that stream was canceled
dataStream.CancelRead(6)
dataStream.CancelWrite(6)
c.mutex.Lock()
delete(c.responses, dataStream.StreamID())
c.mutex.Unlock()
return nil, ctx.Err()
case <-c.headerErrored:
// an error occurred on the header stream
_ = c.closeWithError(c.headerErr)
return nil, c.headerErr
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = dataStream
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
res.Uncompressed = true
}
}
res.Request = req
return res, nil
}
func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
func (c *client) closeWithError(e error) error {
if c.session == nil {
return nil
}
return c.session.CloseWithError(quic.ErrorCode(qerr.InternalError), e)
}
// Close closes the client
func (c *client) Close() error {
if c.session == nil {
return nil
}
return c.session.Close()
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}

View File

@@ -0,0 +1,35 @@
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}

View File

@@ -0,0 +1,77 @@
package h2quic
import (
"crypto/tls"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2/hpack"
)
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
var path, authority, method, contentLengthStr string
httpHeaders := http.Header{}
for _, h := range headers {
switch h.Name {
case ":path":
path = h.Value
case ":method":
method = h.Value
case ":authority":
authority = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
}
}
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
u, err := url.Parse(path)
if err != nil {
return nil, err
}
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: method,
URL: u,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
ContentLength: contentLength,
Host: authority,
RequestURI: path,
TLS: &tls.ConnectionState{},
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if req.URL != nil {
return req.URL.Host
}
return ""
}

View File

@@ -0,0 +1,29 @@
package h2quic
import (
"io"
quic "github.com/lucas-clemente/quic-go"
)
type requestBody struct {
requestRead bool
dataStream quic.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream quic.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}

View File

@@ -0,0 +1,203 @@
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream quic.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
logger utils.Logger
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream quic.Stream, logger utils.Logger) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
logger: logger,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
}
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
w.logger.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}

View File

@@ -0,0 +1,95 @@
package h2quic
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
// TODO: handle statusCode == 100
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
}
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}

View File

@@ -0,0 +1,114 @@
package h2quic
import (
"bytes"
"net/http"
"strconv"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type responseWriter struct {
dataStreamID protocol.StreamID
dataStream quic.Stream
headerStream quic.Stream
headerStreamMutex *sync.Mutex
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
logger utils.Logger
}
func newResponseWriter(
headerStream quic.Stream,
headerStreamMutex *sync.Mutex,
dataStream quic.Stream,
dataStreamID protocol.StreamID,
logger utils.Logger,
) *responseWriter {
return &responseWriter{
header: http.Header{},
headerStream: headerStream,
headerStreamMutex: headerStreamMutex,
dataStream: dataStream,
dataStreamID: dataStreamID,
logger: logger,
}
}
func (w *responseWriter) Header() http.Header {
return w.header
}
func (w *responseWriter) WriteHeader(status int) {
if w.headerWritten {
return
}
w.headerWritten = true
w.status = status
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(status)})
for k, v := range w.header {
for index := range v {
enc.WriteField(hpack.HeaderField{Name: strings.ToLower(k), Value: v[index]})
}
}
w.logger.Infof("Responding with %d", status)
w.headerStreamMutex.Lock()
defer w.headerStreamMutex.Unlock()
h2framer := http2.NewFramer(w.headerStream, nil)
err := h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(w.dataStreamID),
EndHeaders: true,
BlockFragment: headers.Bytes(),
})
if err != nil {
w.logger.Errorf("could not write h2 header: %s", err.Error())
}
}
func (w *responseWriter) Write(p []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(200)
}
if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed
}
return w.dataStream.Write(p)
}
func (w *responseWriter) Flush() {}
// This is a NOP. Use http.Request.Context
func (w *responseWriter) CloseNotify() <-chan bool { return make(<-chan bool) }
// test that we implement http.Flusher
var _ http.Flusher = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}

View File

@@ -0,0 +1,9 @@
package h2quic
import "net/http"
// The CloseNotifier is a deprecated interface, and staticcheck will report that from Go 1.11.
// By defining it in a separate file, we can exclude this file from staticcheck.
// test that we implement http.CloseNotifier
var _ http.CloseNotifier = &responseWriter{}

View File

@@ -0,0 +1,179 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
quic "github.com/lucas-clemente/quic-go"
"golang.org/x/net/http/httpguts"
)
type roundTripCloser interface {
http.RoundTripper
io.Closer
}
// RoundTripper implements the http.RoundTripper interface
type RoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
// QuicConfig is the quic.Config used for dialing new connections.
// If nil, reasonable default values will be used.
QuicConfig *quic.Config
// Dial specifies an optional dial function for creating QUIC
// connections for requests.
// If Dial is nil, quic.DialAddr will be used.
Dial func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error)
clients map[string]roundTripCloser
}
// RoundTripOpt are options for the Transport.RoundTripOpt method.
type RoundTripOpt struct {
// OnlyCachedConn controls whether the RoundTripper may
// create a new QUIC connection. If set true and
// no cached connection is available, RoundTrip
// will return ErrNoCachedConn.
OnlyCachedConn bool
}
var _ roundTripCloser = &RoundTripper{}
// ErrNoCachedConn is returned when RoundTripper.OnlyCachedConn is set
var ErrNoCachedConn = errors.New("h2quic: no cached connection was available")
// RoundTripOpt is like RoundTrip, but takes options.
func (r *RoundTripper) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header")
}
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
}
} else {
closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
cl, err := r.getClient(hostname, opt.OnlyCachedConn)
if err != nil {
return nil, err
}
return cl.RoundTrip(req)
}
// RoundTrip does a round trip.
func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return r.RoundTripOpt(req, RoundTripOpt{})
}
func (r *RoundTripper) getClient(hostname string, onlyCached bool) (http.RoundTripper, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]roundTripCloser)
}
client, ok := r.clients[hostname]
if !ok {
if onlyCached {
return nil, ErrNoCachedConn
}
client = newClient(
hostname,
r.TLSClientConfig,
&roundTripperOpts{DisableCompression: r.DisableCompression},
r.QuicConfig,
r.Dial,
)
r.clients[hostname] = client
}
return client, nil
}
// Close closes the QUIC connections that this RoundTripper has used
func (r *RoundTripper) Close() error {
r.mutex.Lock()
defer r.mutex.Unlock()
for _, client := range r.clients {
if err := client.Close(); err != nil {
return err
}
}
r.clients = nil
return nil
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httpguts.IsTokenRune(r)
}

View File

@@ -0,0 +1,402 @@
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"runtime"
"strings"
"sync"
"sync/atomic"
"time"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
type streamCreator interface {
quic.Session
GetOrOpenStream(protocol.StreamID) (quic.Stream, error)
}
type remoteCloser interface {
CloseRemote(protocol.ByteCount)
}
// allows mocking of quic.Listen and quic.ListenAddr
var (
quicListen = quic.Listen
quicListenAddr = quic.ListenAddr
)
// Server is a HTTP2 server listening for QUIC connections.
type Server struct {
*http.Server
// By providing a quic.Config, it is possible to set parameters of the QUIC connection.
// If nil, it uses reasonable default values.
QuicConfig *quic.Config
// Private flag for demo, do not use
CloseAfterFirstRequest bool
port uint32 // used atomically
listenerMutex sync.Mutex
listener quic.Listener
closed bool
supportedVersionsAsString string
logger utils.Logger // will be set by Server.serveImpl()
}
// ListenAndServe listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServe() error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
return s.serveImpl(s.TLSConfig, nil)
}
// ListenAndServeTLS listens on the UDP address s.Addr and calls s.Handler to handle HTTP/2 requests on incoming connections.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
return s.serveImpl(config, nil)
}
// Serve an existing UDP connection.
func (s *Server) Serve(conn net.PacketConn) error {
return s.serveImpl(s.TLSConfig, conn)
}
func (s *Server) serveImpl(tlsConfig *tls.Config, conn net.PacketConn) error {
if s.Server == nil {
return errors.New("use of h2quic.Server without http.Server")
}
s.logger = utils.DefaultLogger.WithPrefix("server")
s.listenerMutex.Lock()
if s.closed {
s.listenerMutex.Unlock()
return errors.New("Server is already closed")
}
if s.listener != nil {
s.listenerMutex.Unlock()
return errors.New("ListenAndServe may only be called once")
}
var ln quic.Listener
var err error
if conn == nil {
ln, err = quicListenAddr(s.Addr, tlsConfig, s.QuicConfig)
} else {
ln, err = quicListen(conn, tlsConfig, s.QuicConfig)
}
if err != nil {
s.listenerMutex.Unlock()
return err
}
s.listener = ln
s.listenerMutex.Unlock()
for {
sess, err := ln.Accept()
if err != nil {
return err
}
go s.handleHeaderStream(sess.(streamCreator))
}
}
func (s *Server) handleHeaderStream(session streamCreator) {
stream, err := session.AcceptStream()
if err != nil {
session.CloseWithError(quic.ErrorCode(qerr.InvalidHeadersStreamData), err)
return
}
hpackDecoder := hpack.NewDecoder(4096, nil)
h2framer := http2.NewFramer(nil, stream)
var headerStreamMutex sync.Mutex // Protects concurrent calls to Write()
for {
if err := s.handleRequest(session, stream, &headerStreamMutex, hpackDecoder, h2framer); err != nil {
// QuicErrors must originate from stream.Read() returning an error.
// In this case, the session has already logged the error, so we don't
// need to log it again.
errorCode := qerr.InternalError
if qerr, ok := err.(*qerr.QuicError); ok {
errorCode = qerr.ErrorCode
s.logger.Errorf("error handling h2 request: %s", err.Error())
}
session.CloseWithError(quic.ErrorCode(errorCode), err)
return
}
}
}
func (s *Server) handleRequest(session streamCreator, headerStream quic.Stream, headerStreamMutex *sync.Mutex, hpackDecoder *hpack.Decoder, h2framer *http2.Framer) error {
h2frame, err := h2framer.ReadFrame()
if err != nil {
return qerr.Error(qerr.HeadersStreamDataDecompressFailure, "cannot read frame")
}
var h2headersFrame *http2.HeadersFrame
switch f := h2frame.(type) {
case *http2.PriorityFrame:
// ignore PRIORITY frames
s.logger.Debugf("Ignoring H2 PRIORITY frame: %#v", f)
return nil
case *http2.HeadersFrame:
h2headersFrame = f
default:
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
headers, err := hpackDecoder.DecodeFull(h2headersFrame.HeaderBlockFragment())
if err != nil {
s.logger.Errorf("invalid http2 headers encoding: %s", err.Error())
return err
}
req, err := requestFromHeaders(headers)
if err != nil {
return err
}
if s.logger.Debug() {
s.logger.Infof("%s %s%s, on data stream %d", req.Method, req.Host, req.RequestURI, h2headersFrame.StreamID)
} else {
s.logger.Infof("%s %s%s", req.Method, req.Host, req.RequestURI)
}
dataStream, err := session.GetOrOpenStream(protocol.StreamID(h2headersFrame.StreamID))
if err != nil {
return err
}
// this can happen if the client immediately closes the data stream after sending the request and the runtime processes the reset before the request
if dataStream == nil {
return nil
}
// handleRequest should be as non-blocking as possible to minimize
// head-of-line blocking. Potentially blocking code is run in a separate
// goroutine, enabling handleRequest to return before the code is executed.
go func() {
streamEnded := h2headersFrame.StreamEnded()
if streamEnded {
dataStream.(remoteCloser).CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
req = req.WithContext(dataStream.Context())
reqBody := newRequestBody(dataStream)
req.Body = reqBody
req.RemoteAddr = session.RemoteAddr().String()
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID), s.logger)
handler := s.Handler
if handler == nil {
handler = http.DefaultServeMux
}
panicked := false
func() {
defer func() {
if p := recover(); p != nil {
// Copied from net/http/server.go
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
s.logger.Errorf("http: panic serving: %v\n%s", p, buf)
panicked = true
}
}()
handler.ServeHTTP(responseWriter, req)
}()
if panicked {
responseWriter.WriteHeader(500)
} else {
responseWriter.WriteHeader(200)
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
// in gQUIC, the error code doesn't matter, so just use 0 here
responseWriter.dataStream.CancelRead(0)
}
responseWriter.dataStream.Close()
}
if s.CloseAfterFirstRequest {
time.Sleep(100 * time.Millisecond)
session.Close()
}
}()
return nil
}
// Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients.
// Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) Close() error {
s.listenerMutex.Lock()
defer s.listenerMutex.Unlock()
s.closed = true
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return nil
}
// CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete.
// CloseGracefully in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established.
func (s *Server) CloseGracefully(timeout time.Duration) error {
// TODO: implement
return nil
}
// SetQuicHeaders can be used to set the proper headers that announce that this server supports QUIC.
// The values that are set depend on the port information from s.Server.Addr, and currently look like this (if Addr has port 443):
// Alt-Svc: quic=":443"; ma=2592000; v="33,32,31,30"
func (s *Server) SetQuicHeaders(hdr http.Header) error {
port := atomic.LoadUint32(&s.port)
if port == 0 {
// Extract port from s.Server.Addr
_, portStr, err := net.SplitHostPort(s.Server.Addr)
if err != nil {
return err
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
}
port = uint32(portInt)
atomic.StoreUint32(&s.port, port)
}
if s.supportedVersionsAsString == "" {
var versions []string
for _, v := range protocol.SupportedVersions {
versions = append(versions, v.ToAltSvc())
}
s.supportedVersionsAsString = strings.Join(versions, ",")
}
hdr.Add("Alt-Svc", fmt.Sprintf(`quic=":%d"; ma=2592000; v="%s"`, port, s.supportedVersionsAsString))
return nil
}
// ListenAndServeQUIC listens on the UDP network address addr and calls the
// handler for HTTP/2 requests on incoming connections. http.DefaultServeMux is
// used when handler is nil.
func ListenAndServeQUIC(addr, certFile, keyFile string, handler http.Handler) error {
server := &Server{
Server: &http.Server{
Addr: addr,
Handler: handler,
},
}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServe listens on the given network address for both, TLS and QUIC
// connetions in parallel. It returns if one of the two returns an error.
// http.DefaultServeMux is used when handler is nil.
// The correct Alt-Svc headers for QUIC are set.
func ListenAndServe(addr, certFile, keyFile string, handler http.Handler) error {
// Load certs
var err error
certs := make([]tls.Certificate, 1)
certs[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return err
}
// We currently only use the cert-related stuff from tls.Config,
// so we don't need to make a full copy.
config := &tls.Config{
Certificates: certs,
}
// Open the listeners
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
defer udpConn.Close()
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return err
}
tcpConn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return err
}
defer tcpConn.Close()
tlsConn := tls.NewListener(tcpConn, config)
defer tlsConn.Close()
// Start the servers
httpServer := &http.Server{
Addr: addr,
TLSConfig: config,
}
quicServer := &Server{
Server: httpServer,
}
if handler == nil {
handler = http.DefaultServeMux
}
httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
quicServer.SetQuicHeaders(w.Header())
handler.ServeHTTP(w, r)
})
hErr := make(chan error)
qErr := make(chan error)
go func() {
hErr <- httpServer.Serve(tlsConn)
}()
go func() {
qErr <- quicServer.Serve(udpConn)
}()
select {
case err := <-hErr:
quicServer.Close()
return err
case err := <-qErr:
// Cannot close the HTTP server or wait for requests to complete properly :/
return err
}
}

221
vendor/github.com/lucas-clemente/quic-go/interface.go generated vendored Normal file
View File

@@ -0,0 +1,221 @@
package quic
import (
"context"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// The StreamID is the ID of a QUIC stream.
type StreamID = protocol.StreamID
// A VersionNumber is a QUIC version number.
type VersionNumber = protocol.VersionNumber
const (
// VersionGQUIC39 is gQUIC version 39.
VersionGQUIC39 = protocol.Version39
// VersionGQUIC43 is gQUIC version 43.
VersionGQUIC43 = protocol.Version43
// VersionGQUIC44 is gQUIC version 44.
VersionGQUIC44 = protocol.Version44
)
// A Cookie can be used to verify the ownership of the client address.
type Cookie = handshake.Cookie
// ConnectionState records basic details about the QUIC connection.
type ConnectionState = handshake.ConnectionState
// An ErrorCode is an application-defined error code.
type ErrorCode = protocol.ApplicationErrorCode
// Stream is the interface implemented by QUIC streams
type Stream interface {
// StreamID returns the stream ID.
StreamID() StreamID
// Read reads data from the stream.
// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Reader
// Write writes data to the stream.
// Write can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
// If the stream was canceled by the peer, the error implements the StreamError
// interface, and Canceled() == true.
io.Writer
// Close closes the write-direction of the stream.
// Future calls to Write are not permitted after calling Close.
// It must not be called concurrently with Write.
// It must not be called after calling CancelWrite.
io.Closer
// CancelWrite aborts sending on this stream.
// It must not be called after Close.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
CancelWrite(ErrorCode) error
// CancelRead aborts receiving on this stream.
// It will ask the peer to stop transmitting stream data.
// Read will unblock immediately, and future Read calls will fail.
CancelRead(ErrorCode) error
// The context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() is called, or when the stream is reset (either locally or remotely).
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// SetReadDeadline sets the deadline for future Read calls and
// any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
// SetWriteDeadline sets the deadline for future Write calls
// and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
SetWriteDeadline(t time.Time) error
// SetDeadline sets the read and write deadlines associated
// with the connection. It is equivalent to calling both
// SetReadDeadline and SetWriteDeadline.
SetDeadline(t time.Time) error
}
// A ReceiveStream is a unidirectional Receive Stream.
type ReceiveStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Read
io.Reader
// see Stream.CancelRead
CancelRead(ErrorCode) error
// see Stream.SetReadDealine
SetReadDeadline(t time.Time) error
}
// A SendStream is a unidirectional Send Stream.
type SendStream interface {
// see Stream.StreamID
StreamID() StreamID
// see Stream.Write
io.Writer
// see Stream.Close
io.Closer
// see Stream.CancelWrite
CancelWrite(ErrorCode) error
// see Stream.Context
Context() context.Context
// see Stream.SetWriteDeadline
SetWriteDeadline(t time.Time) error
}
// StreamError is returned by Read and Write when the peer cancels the stream.
type StreamError interface {
error
Canceled() bool
ErrorCode() ErrorCode
}
// A Session is a QUIC connection between two peers.
type Session interface {
// AcceptStream returns the next stream opened by the peer, blocking until one is available.
AcceptStream() (Stream, error)
// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
AcceptUniStream() (ReceiveStream, error)
// OpenStream opens a new bidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// There is no signaling to the peer about new streams:
// The peer can only accept the stream after data has been sent on the stream.
// TODO(#1152): Enable testing for the special error
OpenStream() (Stream, error)
// OpenStreamSync opens a new bidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenStreamSync() (Stream, error)
// OpenUniStream opens a new outgoing unidirectional QUIC stream.
// It returns a special error when the peer's concurrent stream limit is reached.
// TODO(#1152): Enable testing for the special error
OpenUniStream() (SendStream, error)
// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
// It blocks until the peer's concurrent stream limit allows a new stream to be opened.
OpenUniStreamSync() (SendStream, error)
// LocalAddr returns the local address.
LocalAddr() net.Addr
// RemoteAddr returns the address of the peer.
RemoteAddr() net.Addr
// Close the connection.
io.Closer
// Close the connection with an error.
// The error must not be nil.
CloseWithError(ErrorCode, error) error
// The context is cancelled when the session is closed.
// Warning: This API should not be considered stable and might change soon.
Context() context.Context
// ConnectionState returns basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
ConnectionState() ConnectionState
}
// Config contains all configuration data needed for a QUIC server or client.
type Config struct {
// The QUIC versions that can be negotiated.
// If not set, it uses all versions available.
// Warning: This API should not be considered stable and will change soon.
Versions []VersionNumber
// Ask the server to omit the connection ID sent in the Public Header.
// This saves 8 bytes in the Public Header in every packet. However, if the IP address of the server changes, the connection cannot be migrated.
// Currently only valid for the client.
RequestConnectionIDOmission bool
// The length of the connection ID in bytes. Only valid for IETF QUIC.
// It can be 0, or any value between 4 and 18.
// If not set, the interpretation depends on where the Config is used:
// If used for dialing an address, a 0 byte connection ID will be used.
// If used for a server, or dialing on a packet conn, a 4 byte connection ID will be used.
// When dialing on a packet conn, the ConnectionIDLength value must be the same for every Dial call.
ConnectionIDLength int
// HandshakeTimeout is the maximum duration that the cryptographic handshake may take.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 10 seconds.
HandshakeTimeout time.Duration
// IdleTimeout is the maximum duration that may pass without any incoming network activity.
// This value only applies after the handshake has completed.
// If the timeout is exceeded, the connection is closed.
// If this value is zero, the timeout is set to 30 seconds.
IdleTimeout time.Duration
// AcceptCookie determines if a Cookie is accepted.
// It is called with cookie = nil if the client didn't send an Cookie.
// If not set, it verifies that the address matches, and that the Cookie was issued within the last 24 hours.
// This option is only valid for the server.
AcceptCookie func(clientAddr net.Addr, cookie *Cookie) bool
// MaxReceiveStreamFlowControlWindow is the maximum stream-level flow control window for receiving data.
// If this value is zero, it will default to 1 MB for the server and 6 MB for the client.
MaxReceiveStreamFlowControlWindow uint64
// MaxReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data.
// If this value is zero, it will default to 1.5 MB for the server and 15 MB for the client.
MaxReceiveConnectionFlowControlWindow uint64
// MaxIncomingStreams is the maximum number of concurrent bidirectional streams that a peer is allowed to open.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any bidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingStreams int
// MaxIncomingUniStreams is the maximum number of concurrent unidirectional streams that a peer is allowed to open.
// This value doesn't have any effect in Google QUIC.
// If not set, it will default to 100.
// If set to a negative value, it doesn't allow any unidirectional streams.
// Values larger than 65535 (math.MaxUint16) are invalid.
MaxIncomingUniStreams int
// KeepAlive defines whether this peer will periodically send PING frames to keep the connection alive.
KeepAlive bool
}
// A Listener for incoming QUIC connections
type Listener interface {
// Close the server, sending CONNECTION_CLOSE frames to each peer.
Close() error
// Addr returns the local network addr that the server is listening on.
Addr() net.Addr
// Accept returns new sessions. It should be called in a loop.
Accept() (Session, error)
}

View File

@@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View File

@@ -0,0 +1,47 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber)
ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, recvTime time.Time) error
SetHandshakeComplete()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// ShouldSendNumPackets returns the number of packets that should be sent immediately.
// It always returns a number greater or equal than 1.
// A number greater than 1 is returned when the pacing delay is smaller than the minimum pacing delay.
// Note that the number of packets is only calculated based on the pacing algorithm.
// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
ShouldSendNumPackets() int
GetStopWaitingFrame(force bool) *wire.StopWaitingFrame
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
DequeuePacketForRetransmission() *Packet
DequeueProbePacket() (*Packet, error)
GetPacketNumberLen(protocol.PacketNumber) protocol.PacketNumberLen
GetAlarmTimeout() time.Time
OnAlarm() error
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
IgnoreBelow(protocol.PacketNumber)
GetAlarmTimeout() time.Time
GetAckFrame() *wire.AckFrame
}

View File

@@ -0,0 +1,29 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
PacketType protocol.PacketType
Frames []wire.Frame
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
largestAcked protocol.PacketNumber // if the packet contains an ACK, the LargestAcked value of that ACK
// There are two reasons why a packet cannot be retransmitted:
// * it was already retransmitted
// * this packet is a retransmission, and we already received an ACK for the original packet
canBeRetransmitted bool
includedInBytesInFlight bool
retransmittedAs []protocol.PacketNumber
isRetransmission bool // we need a separate bool here because 0 is a valid packet number
retransmissionOf protocol.PacketNumber
}

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package ackhandler
// Linked list implementation from the Go standard library.
// PacketElement is an element of a linked list.
type PacketElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *PacketElement
// The list to which this element belongs.
list *PacketList
// The value stored with this element.
Value Packet
}
// Next returns the next list element or nil.
func (e *PacketElement) Next() *PacketElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *PacketElement) Prev() *PacketElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// PacketList is a linked list of Packets.
type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *PacketList) Init() *PacketList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewPacketList returns an initialized list.
func NewPacketList() *PacketList { return new(PacketList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *PacketList) remove(e *PacketElement) *PacketElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *PacketList) PushFront(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *PacketList) PushBack(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,215 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
ackSendDelay time.Duration
rttStats *congestion.RTTStats
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
const (
// maximum delay that can be applied to an ACK for a retransmittable packet
ackSendDelay = 25 * time.Millisecond
// initial maximum number of retransmittable packets received before sending an ack.
initialRetransmittablePacketsBeforeAck = 2
// number of retransmittable that an ACK is sent for
retransmittablePacketsBeforeAck = 10
// 1/5 RTT delay when doing ack decimation
ackDecimationDelay = 1.0 / 4
// 1/8 RTT delay when doing ack decimation
shortAckDecimationDelay = 1.0 / 8
// Minimum number of packets received before ack decimation is enabled.
// This intends to avoid the beginning of slow start, when CWNDs may be
// rapidly increasing.
minReceivedBeforeAckDecimation = 100
// Maximum number of packets to ack immediately after a missing packet for
// fast retransmission to kick in at the sender. This limit is created to
// reduce the number of acks sent that have no benefit for fast retransmission.
// Set to the number of nacks needed for fast retransmit plus one for protection
// against an ack loss
maxPacketsAfterNewMissing = 4
)
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler(
rttStats *congestion.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
packetHistory: newReceivedPacketHistory(),
ackSendDelay: ackSendDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
if packetNumber < h.ignoreBelow {
return nil
}
isMissing := h.isMissing(packetNumber)
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
return err
}
h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
return nil
}
// IgnoreBelow sets a lower limit for acking packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %#x.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketHandler) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
}
// maybeQueueAck queues an ACK, if necessary.
// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
// in ACK_DECIMATION_WITH_REORDERING mode.
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
h.packetsReceivedSinceLastAck++
// always ack the first packet
if h.lastAck == nil {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
h.ackQueued = true
return
}
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
}
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
if packetNumber > minReceivedBeforeAckDecimation {
// ack up to 10 packets at once
if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
h.ackQueued = true
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
}
} else if h.ackAlarm.IsZero() {
// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
h.ackAlarm = rcvTime.Add(ackDelay)
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
} else {
// send an ACK every 2 retransmittable packets
if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
}
h.ackAlarm = rcvTime.Add(ackSendDelay)
}
}
// If there are new missing packets to report, set a short timer to send an ACK.
if h.hasNewMissingPackets() {
// wait the minimum of 1/8 min RTT and the existing ack time
ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
ackTime := rcvTime.Add(ackDelay)
if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
h.ackAlarm = ackTime
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
}
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
now := time.Now()
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
DelayTime: now.Sub(h.largestObservedReceivedTime),
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

View File

@@ -0,0 +1,121 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
// The receivedPacketHistory stores if a packet number has already been received.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
lowestInReceivedPacketNumbers protocol.PacketNumber
}
var errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
// newReceivedPacketHistory creates a new received packet history
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(),
}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges {
return errTooManyOutstandingReceivedAckRanges
}
if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return nil
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
return nil
}
var rangeExtended bool
if el.Value.End == p-1 { // extend a range at the end
rangeExtended = true
el.Value.End = p
} else if el.Value.Start == p+1 { // extend a range at the beginning
rangeExtended = true
el.Value.Start = p
}
// if a range was extended (either at the beginning or at the end, maybe it is possible to merge two ranges into one)
if rangeExtended {
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
return nil
}
return nil // if the two ranges were not merge, we're done here
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
return nil
}
}
// create a new range at the beginning
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
return nil
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p <= h.lowestInReceivedPacketNumbers {
return
}
h.lowestInReceivedPacketNumbers = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
} else if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else { // no ranges affected. Nothing to do
return
}
}
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
if h.ranges.Len() == 0 {
return nil
}
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}

View File

@@ -0,0 +1,36 @@
package ackhandler
import "github.com/lucas-clemente/quic-go/internal/wire"
// Returns a new slice with all non-retransmittable frames deleted.
func stripNonRetransmittableFrames(fs []wire.Frame) []wire.Frame {
res := make([]wire.Frame, 0, len(fs))
for _, f := range fs {
if IsFrameRetransmittable(f) {
res = append(res, f)
}
}
return res
}
// IsFrameRetransmittable returns true if the frame should be retransmitted.
func IsFrameRetransmittable(f wire.Frame) bool {
switch f.(type) {
case *wire.StopWaitingFrame:
return false
case *wire.AckFrame:
return false
default:
return true
}
}
// HasRetransmittableFrames returns true if at least one frame is retransmittable.
func HasRetransmittableFrames(fs []wire.Frame) bool {
for _, f := range fs {
if IsFrameRetransmittable(f) {
return true
}
}
return false
}

View File

@@ -0,0 +1,40 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendRetransmission means that retransmissions should be sent
SendRetransmission
// SendRTO means that an RTO probe packet should be sent
SendRTO
// SendTLP means that a TLP probe packet should be sent
SendTLP
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendRetransmission:
return "retransmission"
case SendRTO:
return "rto"
case SendTLP:
return "tlp"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@@ -0,0 +1,659 @@
package ackhandler
import (
"errors"
"fmt"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/qerr"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// In fraction of an RTT.
timeReorderingFraction = 1.0 / 8
// defaultRTOTimeout is the RTO time on new connections
defaultRTOTimeout = 500 * time.Millisecond
// Minimum time in the future a tail loss probe alarm may be set for.
minTPLTimeout = 10 * time.Millisecond
// Maximum number of tail loss probes before an RTO fires.
maxTLPs = 2
// Minimum time in the future an RTO alarm may be set for.
minRTOTimeout = 200 * time.Millisecond
// maxRTOTimeout is the maximum RTO time
maxRTOTimeout = 60 * time.Second
)
type sentPacketHandler struct {
lastSentPacketNumber protocol.PacketNumber
lastSentRetransmittablePacketTime time.Time
lastSentHandshakePacketTime time.Time
nextPacketSendTime time.Time
skippedPackets []protocol.PacketNumber
largestAcked protocol.PacketNumber
largestReceivedPacketWithAck protocol.PacketNumber
// lowestPacketNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestPacketNotConfirmedAcked is 101
lowestPacketNotConfirmedAcked protocol.PacketNumber
largestSentBeforeRTO protocol.PacketNumber
packetHistory *sentPacketHistory
stopWaitingManager stopWaitingManager
retransmissionQueue []*Packet
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithm
rttStats *congestion.RTTStats
handshakeComplete bool
// The number of times the handshake packets have been retransmitted without receiving an ack.
handshakeCount uint32
// The number of times a TLP has been sent without receiving an ack.
tlpCount uint32
allowTLP bool
// The number of times an RTO has been sent without receiving an ack.
rtoCount uint32
// The number of RTO probe packets that should be sent.
numRTOs int
// The time at which the next packet will be considered lost based on early transmit or exceeding the reordering window in time.
lossTime time.Time
// The alarm timeout
alarm time.Time
logger utils.Logger
version protocol.VersionNumber
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
false, /* don't use reno since chromium doesn't (why?) */
protocol.InitialCongestionWindow,
protocol.DefaultMaxCongestionWindow,
)
return &sentPacketHandler{
packetHistory: newSentPacketHistory(),
stopWaitingManager: stopWaitingManager{},
rttStats: rttStats,
congestion: congestion,
logger: logger,
version: version,
}
}
func (h *sentPacketHandler) lowestUnacked() protocol.PacketNumber {
if p := h.packetHistory.FirstOutstanding(); p != nil {
return p.PacketNumber
}
return h.largestAcked + 1
}
func (h *sentPacketHandler) SetHandshakeComplete() {
h.logger.Debugf("Handshake complete. Discarding all outstanding handshake packets.")
var queue []*Packet
for _, packet := range h.retransmissionQueue {
if packet.EncryptionLevel == protocol.EncryptionForwardSecure {
queue = append(queue, packet)
}
}
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.packetHistory.Remove(p.PacketNumber)
}
h.retransmissionQueue = queue
h.handshakeComplete = true
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
h.packetHistory.SentPacket(packet)
h.updateLossDetectionAlarm()
}
}
func (h *sentPacketHandler) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
var p []*Packet
for _, packet := range packets {
if isRetransmittable := h.sentPacketImpl(packet); isRetransmittable {
p = append(p, packet)
}
}
h.packetHistory.SentPacketsAsRetransmission(p, retransmissionOf)
h.updateLossDetectionAlarm()
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* isRetransmittable */ {
for p := h.lastSentPacketNumber + 1; p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %#x", p)
h.skippedPackets = append(h.skippedPackets, p)
if len(h.skippedPackets) > protocol.MaxTrackedSkippedPackets {
h.skippedPackets = h.skippedPackets[1:]
}
}
h.lastSentPacketNumber = packet.PacketNumber
if len(packet.Frames) > 0 {
if ackFrame, ok := packet.Frames[0].(*wire.AckFrame); ok {
packet.largestAcked = ackFrame.LargestAcked()
}
}
packet.Frames = stripNonRetransmittableFrames(packet.Frames)
isRetransmittable := len(packet.Frames) != 0
if isRetransmittable {
if packet.EncryptionLevel < protocol.EncryptionForwardSecure {
h.lastSentHandshakePacketTime = packet.SendTime
}
h.lastSentRetransmittablePacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
packet.canBeRetransmitted = true
if h.numRTOs > 0 {
h.numRTOs--
}
h.allowTLP = false
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isRetransmittable)
h.nextPacketSendTime = utils.MaxTime(h.nextPacketSendTime, packet.SendTime).Add(h.congestion.TimeUntilSend(h.bytesInFlight))
return isRetransmittable
}
func (h *sentPacketHandler) ReceivedAck(ackFrame *wire.AckFrame, withPacketNumber protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
largestAcked := ackFrame.LargestAcked()
if largestAcked > h.lastSentPacketNumber {
return qerr.Error(qerr.InvalidAckData, "Received ACK for an unsent package")
}
// duplicate or out of order ACK
if withPacketNumber != 0 && withPacketNumber <= h.largestReceivedPacketWithAck {
h.logger.Debugf("Ignoring ACK frame (duplicate or out of order).")
return nil
}
h.largestReceivedPacketWithAck = withPacketNumber
h.largestAcked = utils.MaxPacketNumber(h.largestAcked, largestAcked)
if h.skippedPacketsAcked(ackFrame) {
return qerr.Error(qerr.InvalidAckData, "Received an ACK for a skipped packet number")
}
if rttUpdated := h.maybeUpdateRTT(largestAcked, ackFrame.DelayTime, rcvTime); rttUpdated {
h.congestion.MaybeExitSlowStart()
}
ackedPackets, err := h.determineNewlyAckedPackets(ackFrame)
if err != nil {
return err
}
priorInFlight := h.bytesInFlight
for _, p := range ackedPackets {
if encLevel < p.EncryptionLevel {
return fmt.Errorf("Received ACK with encryption level %s that acks a packet %d (encryption level %s)", encLevel, p.PacketNumber, p.EncryptionLevel)
}
// largestAcked == 0 either means that the packet didn't contain an ACK, or it just acked packet 0
// It is safe to ignore the corner case of packets that just acked packet 0, because
// the lowestPacketNotConfirmedAcked is only used to limit the number of ACK ranges we will send.
if p.largestAcked != 0 {
h.lowestPacketNotConfirmedAcked = utils.MaxPacketNumber(h.lowestPacketNotConfirmedAcked, p.largestAcked+1)
}
if err := h.onPacketAcked(p, rcvTime); err != nil {
return err
}
if p.includedInBytesInFlight {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
}
if err := h.detectLostPackets(rcvTime, priorInFlight); err != nil {
return err
}
h.updateLossDetectionAlarm()
h.garbageCollectSkippedPackets()
h.stopWaitingManager.ReceivedAck(ackFrame)
return nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestPacketNotConfirmedAcked
}
func (h *sentPacketHandler) determineNewlyAckedPackets(ackFrame *wire.AckFrame) ([]*Packet, error) {
var ackedPackets []*Packet
ackRangeIndex := 0
lowestAcked := ackFrame.LowestAcked()
largestAcked := ackFrame.LargestAcked()
err := h.packetHistory.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ackFrame.HasMissingRanges() {
ackRange := ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ackFrame.AckRanges)-1 {
ackRangeIndex++
ackRange = ackFrame.AckRanges[len(ackFrame.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber >= ackRange.Smallest { // packet i contained in ACK range
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet 0x%x, while evaluating range 0x%x -> 0x%x", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
ackedPackets = append(ackedPackets, p)
}
} else {
ackedPackets = append(ackedPackets, p)
}
return true, nil
})
if h.logger.Debug() && len(ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(ackedPackets))
for i, p := range ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %#x", len(pns), pns)
}
return ackedPackets, err
}
func (h *sentPacketHandler) maybeUpdateRTT(largestAcked protocol.PacketNumber, ackDelay time.Duration, rcvTime time.Time) bool {
if p := h.packetHistory.GetPacket(largestAcked); p != nil {
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
return true
}
return false
}
func (h *sentPacketHandler) updateLossDetectionAlarm() {
// Cancel the alarm if no packets are outstanding
if !h.packetHistory.HasOutstandingPackets() {
h.alarm = time.Time{}
return
}
if h.packetHistory.HasOutstandingHandshakePackets() {
h.alarm = h.lastSentHandshakePacketTime.Add(h.computeHandshakeTimeout())
} else if !h.lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = h.lossTime
} else {
// RTO or TLP alarm
alarmDuration := h.computeRTOTimeout()
if h.tlpCount < maxTLPs {
tlpAlarm := h.computeTLPTimeout()
// if the RTO duration is shorter than the TLP duration, use the RTO duration
alarmDuration = utils.MinDuration(alarmDuration, tlpAlarm)
}
h.alarm = h.lastSentRetransmittablePacketTime.Add(alarmDuration)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, priorInFlight protocol.ByteCount) error {
h.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
delayUntilLost := time.Duration((1.0 + timeReorderingFraction) * maxRTT)
var lostPackets []*Packet
h.packetHistory.Iterate(func(packet *Packet) (bool, error) {
if packet.PacketNumber > h.largestAcked {
return false, nil
}
timeSinceSent := now.Sub(packet.SendTime)
if timeSinceSent > delayUntilLost {
lostPackets = append(lostPackets, packet)
} else if h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %#x to %s (in %s)", packet.PacketNumber, delayUntilLost, delayUntilLost-timeSinceSent)
}
// Note: This conditional is only entered once per call
h.lossTime = now.Add(delayUntilLost - timeSinceSent)
}
return true, nil
})
if h.logger.Debug() && len(lostPackets) > 0 {
pns := make([]protocol.PacketNumber, len(lostPackets))
for i, p := range lostPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tlost packets (%d): %#x", len(pns), pns)
}
for _, p := range lostPackets {
// the bytes in flight need to be reduced no matter if this packet will be retransmitted
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
if p.canBeRetransmitted {
// queue the packet for retransmission, and report the loss to the congestion controller
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
h.packetHistory.Remove(p.PacketNumber)
}
return nil
}
func (h *sentPacketHandler) OnAlarm() error {
// When all outstanding are acknowledged, the alarm is canceled in
// updateLossDetectionAlarm. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.packetHistory.HasOutstandingPackets() {
if err := h.onVerifiedAlarm(); err != nil {
return err
}
}
h.updateLossDetectionAlarm()
return nil
}
func (h *sentPacketHandler) onVerifiedAlarm() error {
var err error
if h.packetHistory.HasOutstandingHandshakePackets() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in handshake mode. Handshake count: %d", h.handshakeCount)
}
h.handshakeCount++
err = h.queueHandshakePacketsForRetransmission()
} else if !h.lossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", h.lossTime)
}
// Early retransmit or time loss detection
err = h.detectLostPackets(time.Now(), h.bytesInFlight)
} else if h.tlpCount < maxTLPs { // TLP
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in TLP mode. TLP count: %d", h.tlpCount)
}
h.allowTLP = true
h.tlpCount++
} else { // RTO
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in RTO mode. RTO count: %d", h.rtoCount)
}
if h.rtoCount == 0 {
h.largestSentBeforeRTO = h.lastSentPacketNumber
}
h.rtoCount++
h.numRTOs += 2
}
return err
}
func (h *sentPacketHandler) GetAlarmTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) onPacketAcked(p *Packet, rcvTime time.Time) error {
// This happens if a packet and its retransmissions is acked in the same ACK.
// As soon as we process the first one, this will remove all the retransmissions,
// so we won't find the retransmitted packet number later.
if packet := h.packetHistory.GetPacket(p.PacketNumber); packet == nil {
return nil
}
// only report the acking of this packet to the congestion controller if:
// * it is a retransmittable packet
// * this packet wasn't retransmitted yet
if p.isRetransmission {
// that the parent doesn't exist is expected to happen every time the original packet was already acked
if parent := h.packetHistory.GetPacket(p.retransmissionOf); parent != nil {
if len(parent.retransmittedAs) == 1 {
parent.retransmittedAs = nil
} else {
// remove this packet from the slice of retransmission
retransmittedAs := make([]protocol.PacketNumber, 0, len(parent.retransmittedAs)-1)
for _, pn := range parent.retransmittedAs {
if pn != p.PacketNumber {
retransmittedAs = append(retransmittedAs, pn)
}
}
parent.retransmittedAs = retransmittedAs
}
}
}
// this also applies to packets that have been retransmitted as probe packets
if p.includedInBytesInFlight {
h.bytesInFlight -= p.Length
}
if h.rtoCount > 0 {
h.verifyRTO(p.PacketNumber)
}
if err := h.stopRetransmissionsFor(p); err != nil {
return err
}
h.rtoCount = 0
h.tlpCount = 0
h.handshakeCount = 0
return h.packetHistory.Remove(p.PacketNumber)
}
func (h *sentPacketHandler) stopRetransmissionsFor(p *Packet) error {
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
for _, r := range p.retransmittedAs {
packet := h.packetHistory.GetPacket(r)
if packet == nil {
return fmt.Errorf("sent packet handler BUG: marking packet as not retransmittable %d (retransmission of %d) not found in history", r, p.PacketNumber)
}
h.stopRetransmissionsFor(packet)
}
return nil
}
func (h *sentPacketHandler) verifyRTO(pn protocol.PacketNumber) {
if pn <= h.largestSentBeforeRTO {
h.logger.Debugf("Spurious RTO detected. Received an ACK for %#x (largest sent before RTO: %#x)", pn, h.largestSentBeforeRTO)
// Replace SRTT with latest_rtt and increase the variance to prevent
// a spurious RTO from happening again.
h.rttStats.ExpireSmoothedMetrics()
return
}
h.logger.Debugf("RTO verified. Received an ACK for %#x (largest sent before RTO: %#x", pn, h.largestSentBeforeRTO)
h.congestion.OnRetransmissionTimeout(true)
}
func (h *sentPacketHandler) DequeuePacketForRetransmission() *Packet {
if len(h.retransmissionQueue) == 0 {
return nil
}
packet := h.retransmissionQueue[0]
// Shift the slice and don't retain anything that isn't needed.
copy(h.retransmissionQueue, h.retransmissionQueue[1:])
h.retransmissionQueue[len(h.retransmissionQueue)-1] = nil
h.retransmissionQueue = h.retransmissionQueue[:len(h.retransmissionQueue)-1]
return packet
}
func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
if len(h.retransmissionQueue) == 0 {
p := h.packetHistory.FirstOutstanding()
if p == nil {
return nil, errors.New("cannot dequeue a probe packet. No outstanding packets")
}
if err := h.queuePacketForRetransmission(p); err != nil {
return nil, err
}
}
return h.DequeuePacketForRetransmission(), nil
}
func (h *sentPacketHandler) GetPacketNumberLen(p protocol.PacketNumber) protocol.PacketNumberLen {
return protocol.GetPacketNumberLengthForHeader(p, h.lowestUnacked(), h.version)
}
func (h *sentPacketHandler) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
return h.stopWaitingManager.GetStopWaitingFrame(force)
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := len(h.retransmissionQueue) + h.packetHistory.Len()
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.allowTLP {
return SendTLP
}
if h.numRTOs > 0 {
return SendRTO
}
// Only send ACKs if we're congestion limited.
if cwnd := h.congestion.GetCongestionWindow(); h.bytesInFlight > cwnd {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, cwnd)
}
return SendAck
}
// Send retransmissions first, if there are any.
if len(h.retransmissionQueue) > 0 {
return SendRetransmission
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.nextPacketSendTime
}
func (h *sentPacketHandler) ShouldSendNumPackets() int {
if h.numRTOs > 0 {
// RTO probes should not be paced, but must be sent immediately.
return h.numRTOs
}
delay := h.congestion.TimeUntilSend(h.bytesInFlight)
if delay == 0 || delay > protocol.MinPacingDelay {
return 1
}
return int(math.Ceil(float64(protocol.MinPacingDelay) / float64(delay)))
}
func (h *sentPacketHandler) queueHandshakePacketsForRetransmission() error {
var handshakePackets []*Packet
h.packetHistory.Iterate(func(p *Packet) (bool, error) {
if p.canBeRetransmitted && p.EncryptionLevel < protocol.EncryptionForwardSecure {
handshakePackets = append(handshakePackets, p)
}
return true, nil
})
for _, p := range handshakePackets {
h.logger.Debugf("Queueing packet %#x as a handshake retransmission", p.PacketNumber)
if err := h.queuePacketForRetransmission(p); err != nil {
return err
}
}
return nil
}
func (h *sentPacketHandler) queuePacketForRetransmission(p *Packet) error {
if !p.canBeRetransmitted {
return fmt.Errorf("sent packet handler BUG: packet %d already queued for retransmission", p.PacketNumber)
}
if err := h.packetHistory.MarkCannotBeRetransmitted(p.PacketNumber); err != nil {
return err
}
h.retransmissionQueue = append(h.retransmissionQueue, p)
h.stopWaitingManager.QueuedRetransmissionForPacketNumber(p.PacketNumber)
return nil
}
func (h *sentPacketHandler) computeHandshakeTimeout() time.Duration {
duration := utils.MaxDuration(2*h.rttStats.SmoothedOrInitialRTT(), minTPLTimeout)
// exponential backoff
// There's an implicit limit to this set by the handshake timeout.
return duration << h.handshakeCount
}
func (h *sentPacketHandler) computeTLPTimeout() time.Duration {
// TODO(#1236): include the max_ack_delay
return utils.MaxDuration(h.rttStats.SmoothedOrInitialRTT()*3/2, minTPLTimeout)
}
func (h *sentPacketHandler) computeRTOTimeout() time.Duration {
var rto time.Duration
rtt := h.rttStats.SmoothedRTT()
if rtt == 0 {
rto = defaultRTOTimeout
} else {
rto = rtt + 4*h.rttStats.MeanDeviation()
}
rto = utils.MaxDuration(rto, minRTOTimeout)
// Exponential backoff
rto <<= h.rtoCount
return utils.MinDuration(rto, maxRTOTimeout)
}
func (h *sentPacketHandler) skippedPacketsAcked(ackFrame *wire.AckFrame) bool {
for _, p := range h.skippedPackets {
if ackFrame.AcksPacket(p) {
return true
}
}
return false
}
func (h *sentPacketHandler) garbageCollectSkippedPackets() {
lowestUnacked := h.lowestUnacked()
deleteIndex := 0
for i, p := range h.skippedPackets {
if p < lowestUnacked {
deleteIndex = i + 1
}
}
h.skippedPackets = h.skippedPackets[deleteIndex:]
}

View File

@@ -0,0 +1,168 @@
package ackhandler
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type sentPacketHistory struct {
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
numOutstandingPackets int
numOutstandingHandshakePackets int
firstOutstanding *PacketElement
}
func newSentPacketHistory() *sentPacketHistory {
return &sentPacketHistory{
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
}
}
func (h *sentPacketHistory) SentPacket(p *Packet) {
h.sentPacketImpl(p)
}
func (h *sentPacketHistory) sentPacketImpl(p *Packet) *PacketElement {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
if h.firstOutstanding == nil {
h.firstOutstanding = el
}
if p.canBeRetransmitted {
h.numOutstandingPackets++
if p.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets++
}
}
return el
}
func (h *sentPacketHistory) SentPacketsAsRetransmission(packets []*Packet, retransmissionOf protocol.PacketNumber) {
retransmission, ok := h.packetMap[retransmissionOf]
// The retransmitted packet is not present anymore.
// This can happen if it was acked in between dequeueing of the retransmission and sending.
// Just treat the retransmissions as normal packets.
// TODO: This won't happen if we clear packets queued for retransmission on new ACKs.
if !ok {
for _, packet := range packets {
h.sentPacketImpl(packet)
}
return
}
retransmission.Value.retransmittedAs = make([]protocol.PacketNumber, len(packets))
for i, packet := range packets {
retransmission.Value.retransmittedAs[i] = packet.PacketNumber
el := h.sentPacketImpl(packet)
el.Value.isRetransmission = true
el.Value.retransmissionOf = retransmissionOf
}
}
func (h *sentPacketHistory) GetPacket(p protocol.PacketNumber) *Packet {
if el, ok := h.packetMap[p]; ok {
return &el.Value
}
return nil
}
// Iterate iterates through all packets.
// The callback must not modify the history.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
for el := h.packetList.Front(); cont && el != nil; el = el.Next() {
var err error
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
// It must not be modified (e.g. retransmitted).
// Use DequeueFirstPacketForRetransmission() to retransmit it.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
if h.firstOutstanding == nil {
return nil
}
return &h.firstOutstanding.Value
}
// QueuePacketForRetransmission marks a packet for retransmission.
// A packet can only be queued once.
func (h *sentPacketHistory) MarkCannotBeRetransmitted(pn protocol.PacketNumber) error {
el, ok := h.packetMap[pn]
if !ok {
return fmt.Errorf("sent packet history: packet %d not found", pn)
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
el.Value.canBeRetransmitted = false
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
return nil
}
// readjustFirstOutstanding readjusts the pointer to the first outstanding packet.
// This is necessary every time the first outstanding packet is deleted or retransmitted.
func (h *sentPacketHistory) readjustFirstOutstanding() {
el := h.firstOutstanding.Next()
for el != nil && !el.Value.canBeRetransmitted {
el = el.Next()
}
h.firstOutstanding = el
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
if el == h.firstOutstanding {
h.readjustFirstOutstanding()
}
if el.Value.canBeRetransmitted {
h.numOutstandingPackets--
if h.numOutstandingPackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
if el.Value.EncryptionLevel < protocol.EncryptionForwardSecure {
h.numOutstandingHandshakePackets--
if h.numOutstandingHandshakePackets < 0 {
panic("numOutstandingHandshakePackets negative")
}
}
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.numOutstandingPackets > 0
}
func (h *sentPacketHistory) HasOutstandingHandshakePackets() bool {
return h.numOutstandingHandshakePackets > 0
}

View File

@@ -0,0 +1,43 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// This stopWaitingManager is not supposed to satisfy the StopWaitingManager interface, which is a remnant of the legacy AckHandler, and should be remove once we drop support for QUIC 33
type stopWaitingManager struct {
largestLeastUnackedSent protocol.PacketNumber
nextLeastUnacked protocol.PacketNumber
lastStopWaitingFrame *wire.StopWaitingFrame
}
func (s *stopWaitingManager) GetStopWaitingFrame(force bool) *wire.StopWaitingFrame {
if s.nextLeastUnacked <= s.largestLeastUnackedSent {
if force {
return s.lastStopWaitingFrame
}
return nil
}
s.largestLeastUnackedSent = s.nextLeastUnacked
swf := &wire.StopWaitingFrame{
LeastUnacked: s.nextLeastUnacked,
}
s.lastStopWaitingFrame = swf
return swf
}
func (s *stopWaitingManager) ReceivedAck(ack *wire.AckFrame) {
largestAcked := ack.LargestAcked()
if largestAcked >= s.nextLeastUnacked {
s.nextLeastUnacked = largestAcked + 1
}
}
func (s *stopWaitingManager) QueuedRetransmissionForPacketNumber(p protocol.PacketNumber) {
if p >= s.nextLeastUnacked {
s.nextLeastUnacked = p + 1
}
}

View File

@@ -0,0 +1,22 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Bandwidth of a connection
type Bandwidth uint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}

View File

@@ -0,0 +1,18 @@
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}

View File

@@ -0,0 +1,210 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// This cubic implementation is based on the one found in Chromiums's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const cubeScale = 40
const cubeCongestionWindowScale = 410
const cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / protocol.DefaultTCPMSS
const defaultNumConnections = 2
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+protocol.DefaultTCPMSS < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * protocol.DefaultTCPMSS >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(protocol.DefaultTCPMSS) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}

View File

@@ -0,0 +1,318 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
maxBurstBytes = 3 * protocol.DefaultTCPMSS
renoBeta float32 = 0.7 // Reno backoff factor.
defaultMinimumCongestionWindow protocol.ByteCount = 2 * protocol.DefaultTCPMSS
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
prr PrrSender
rttStats *RTTStats
stats connectionStats
cubic *Cubic
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber protocol.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber protocol.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// When true, exit slow start with large cutback of congestion window.
slowStartLargeReduction bool
// Congestion window in packets.
congestionWindow protocol.ByteCount
// Minimum congestion window in packets.
minCongestionWindow protocol.ByteCount
// Maximum congestion window.
maxCongestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowstartThreshold protocol.ByteCount
// Number of connections to simulate.
numConnections int
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
minSlowStartExitWindow protocol.ByteCount
}
var _ SendAlgorithm = &cubicSender{}
var _ SendAlgorithmWithDebugInfo = &cubicSender{}
// NewCubicSender makes a new cubic sender
func NewCubicSender(clock Clock, rttStats *RTTStats, reno bool, initialCongestionWindow, initialMaxCongestionWindow protocol.ByteCount) SendAlgorithmWithDebugInfo {
return &cubicSender{
rttStats: rttStats,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
minCongestionWindow: defaultMinimumCongestionWindow,
slowstartThreshold: initialMaxCongestionWindow,
maxCongestionWindow: initialMaxCongestionWindow,
numConnections: defaultNumConnections,
cubic: NewCubic(clock),
reno: reno,
}
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration {
if c.InRecovery() {
// PRR is used when in recovery.
if c.prr.CanSend(c.GetCongestionWindow(), bytesInFlight, c.GetSlowStartThreshold()) {
return 0
}
}
delay := c.rttStats.SmoothedRTT() / time.Duration(2*c.GetCongestionWindow())
if !c.InSlowStart() { // adjust delay, such that it's 1.25*cwd/rtt
delay = delay * 8 / 5
}
return delay
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
bytesInFlight protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
if !isRetransmittable {
return
}
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketSent(bytes)
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber <= c.largestSentAtLastCutback && c.largestAckedPacketNumber != 0
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.GetSlowStartThreshold()
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) GetSlowStartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
func (c *cubicSender) ExitSlowstart() {
c.slowstartThreshold = c.congestionWindow
}
func (c *cubicSender) SlowstartThreshold() protocol.ByteCount {
return c.slowstartThreshold
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() && c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/protocol.DefaultTCPMSS) {
c.ExitSlowstart()
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
// PRR is used when in recovery.
c.prr.OnPacketAcked(ackedBytes)
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(
packetNumber protocol.PacketNumber,
lostBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
if c.lastCutbackExitedSlowstart {
c.stats.slowstartPacketsLost++
c.stats.slowstartBytesLost += lostBytes
if c.slowStartLargeReduction {
// Reduce congestion window by lost_bytes for every loss.
c.congestionWindow = utils.MaxByteCount(c.congestionWindow-lostBytes, c.minSlowStartExitWindow)
c.slowstartThreshold = c.congestionWindow
}
}
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
if c.InSlowStart() {
c.stats.slowstartPacketsLost++
}
c.prr.OnPacketLost(priorInFlight)
// TODO(chromium): Separate out all of slow start into a separate class.
if c.slowStartLargeReduction && c.InSlowStart() {
if c.congestionWindow >= 2*c.initialCongestionWindow {
c.minSlowStartExitWindow = c.congestionWindow / 2
}
c.congestionWindow -= protocol.DefaultTCPMSS
} else if c.reno {
c.congestionWindow = protocol.ByteCount(float32(c.congestionWindow) * c.RenoBeta())
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if c.congestionWindow < c.minCongestionWindow {
c.congestionWindow = c.minCongestionWindow
}
c.slowstartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
func (c *cubicSender) RenoBeta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1. + renoBeta) / float32(c.numConnections)
}
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
return
}
if c.congestionWindow >= c.maxCongestionWindow {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += protocol.DefaultTCPMSS
return
}
// Congestion avoidance
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
// Divide by num_connections to smoothly increase the CWND at a faster
// rate than conventional Reno.
if c.numAckedPackets*uint64(c.numConnections) >= uint64(c.congestionWindow)/uint64(protocol.DefaultTCPMSS) {
c.congestionWindow += protocol.DefaultTCPMSS
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow, c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstBytes
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return 0
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// HybridSlowStart returns the hybrid slow start instance for testing
func (c *cubicSender) HybridSlowStart() *HybridSlowStart {
return &c.hybridSlowStart
}
// SetNumEmulatedConnections sets the number of emulated connections
func (c *cubicSender) SetNumEmulatedConnections(n int) {
c.numConnections = utils.Max(n, 1)
c.cubic.SetNumConnections(c.numConnections)
}
// OnRetransmissionTimeout is called on an retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = 0
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowstartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow
}
// OnConnectionMigration is called when the connection is migrated (?)
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.prr = PrrSender{}
c.largestSentPacketNumber = 0
c.largestAckedPacketNumber = 0
c.largestSentAtLastCutback = 0
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowstartThreshold = c.initialMaxCongestionWindow
c.maxCongestionWindow = c.initialMaxCongestionWindow
}
// SetSlowStartLargeReduction allows enabling the SSLR experiment
func (c *cubicSender) SetSlowStartLargeReduction(enabled bool) {
c.slowStartLargeReduction = enabled
}

View File

@@ -0,0 +1,111 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = protocol.ByteCount(16)
// Number of delay samples for detecting the increase of delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const hybridStartDelayMinThresholdUs = int64(4000)
const hybridStartDelayMaxThresholdUs = int64(16000)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few(8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is greater than 16 and
// increasing delay is found.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}

View File

@@ -0,0 +1,36 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A SendAlgorithm performs congestion control and calculates the congestion window
type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Duration
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
GetCongestionWindow() protocol.ByteCount
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
SetNumEmulatedConnections(n int)
OnRetransmissionTimeout(packetsRetransmitted bool)
OnConnectionMigration()
// Experiments
SetSlowStartLargeReduction(enabled bool)
}
// SendAlgorithmWithDebugInfo adds some debug functions to SendAlgorithm
type SendAlgorithmWithDebugInfo interface {
SendAlgorithm
BandwidthEstimate() Bandwidth
// Stuff only used in testing
HybridSlowStart() *HybridSlowStart
SlowstartThreshold() protocol.ByteCount
RenoBeta() float32
InRecovery() bool
}

View File

@@ -0,0 +1,54 @@
package congestion
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// PrrSender implements the Proportional Rate Reduction (PRR) per RFC 6937
type PrrSender struct {
bytesSentSinceLoss protocol.ByteCount
bytesDeliveredSinceLoss protocol.ByteCount
ackCountSinceLoss protocol.ByteCount
bytesInFlightBeforeLoss protocol.ByteCount
}
// OnPacketSent should be called after a packet was sent
func (p *PrrSender) OnPacketSent(sentBytes protocol.ByteCount) {
p.bytesSentSinceLoss += sentBytes
}
// OnPacketLost should be called on the first loss that triggers a recovery
// period and all other methods in this class should only be called when in
// recovery.
func (p *PrrSender) OnPacketLost(priorInFlight protocol.ByteCount) {
p.bytesSentSinceLoss = 0
p.bytesInFlightBeforeLoss = priorInFlight
p.bytesDeliveredSinceLoss = 0
p.ackCountSinceLoss = 0
}
// OnPacketAcked should be called after a packet was acked
func (p *PrrSender) OnPacketAcked(ackedBytes protocol.ByteCount) {
p.bytesDeliveredSinceLoss += ackedBytes
p.ackCountSinceLoss++
}
// CanSend returns if packets can be sent
func (p *PrrSender) CanSend(congestionWindow, bytesInFlight, slowstartThreshold protocol.ByteCount) bool {
// Return QuicTime::Zero In order to ensure limited transmit always works.
if p.bytesSentSinceLoss == 0 || bytesInFlight < protocol.DefaultTCPMSS {
return true
}
if congestionWindow > bytesInFlight {
// During PRR-SSRB, limit outgoing packets to 1 extra MSS per ack, instead
// of sending the entire available window. This prevents burst retransmits
// when more packets are lost than the CWND reduction.
// limit = MAX(prr_delivered - prr_out, DeliveredData) + MSS
return p.bytesDeliveredSinceLoss+p.ackCountSinceLoss*protocol.DefaultTCPMSS > p.bytesSentSinceLoss
}
// Implement Proportional Rate Reduction (RFC6937).
// Checks a simplified version of the PRR formula that doesn't use division:
// AvailableSendWindow =
// CEIL(prr_delivered * ssthresh / BytesInFlightAtLoss) - prr_sent
return p.bytesDeliveredSinceLoss*slowstartThreshold > p.bytesSentSinceLoss*p.bytesInFlightBeforeLoss
}

View File

@@ -0,0 +1,101 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const (
rttAlpha float32 = 0.125
oneMinusAlpha float32 = (1 - rttAlpha)
rttBeta float32 = 0.25
oneMinusBeta float32 = (1 - rttBeta)
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
// RTTStats provides round-trip statistics
type RTTStats struct {
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// SmoothedRTT returns the EWMA smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// SmoothedOrInitialRTT returns the EWMA smoothed RTT for the connection.
// If no valid updates have occurred, it returns the initial RTT.
func (r *RTTStats) SmoothedOrInitialRTT() time.Duration {
if r.smoothedRTT != 0 {
return r.smoothedRTT
}
return defaultInitialRTT
}
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == utils.InfDuration || sendDelta <= 0 {
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
// Correct for ackDelay if information received from the peer results in a
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if r.smoothedRTT == 0 {
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(utils.AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = utils.MaxDuration(r.meanDeviation, utils.AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = utils.MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@@ -0,0 +1,8 @@
package congestion
import "github.com/lucas-clemente/quic-go/internal/protocol"
type connectionStats struct {
slowstartPacketsLost protocol.PacketNumber
slowstartBytesLost protocol.ByteCount
}

View File

@@ -0,0 +1,10 @@
package crypto
import "github.com/lucas-clemente/quic-go/internal/protocol"
// An AEAD implements QUIC's authenticated encryption and associated data
type AEAD interface {
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
Overhead() int
}

View File

@@ -0,0 +1,72 @@
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/aes12"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM12 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM12{}
// NewAEADAESGCM12 creates a AEAD using AES-GCM with 12 bytes tag size
//
// AES-GCM support is a bit hacky, since the go stdlib does not support 12 byte
// tag size, and couples the cipher and aes packages closely.
// See https://github.com/lucas-clemente/aes12.
func NewAEADAESGCM12(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 16 || len(otherKey) != 16 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("AES-GCM: expected 16-byte keys and 4-byte IVs")
}
encrypterCipher, err := aes12.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := aes12.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes12.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := aes12.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM12{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM12) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM12) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}
func (aead *aeadAESGCM12) Overhead() int {
return aead.encrypter.Overhead()
}

View File

@@ -0,0 +1,74 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadAESGCM struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
var _ AEAD = &aeadAESGCM{}
const ivLen = 12
// NewAEADAESGCM creates a AEAD using AES-GCM
func NewAEADAESGCM(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
// the IVs need to be at least 8 bytes long, otherwise we can't compute the nonce
if len(otherIV) != ivLen || len(myIV) != ivLen {
return nil, errors.New("AES-GCM: expected 12 byte IVs")
}
encrypterCipher, err := aes.NewCipher(myKey)
if err != nil {
return nil, err
}
encrypter, err := cipher.NewGCM(encrypterCipher)
if err != nil {
return nil, err
}
decrypterCipher, err := aes.NewCipher(otherKey)
if err != nil {
return nil, err
}
decrypter, err := cipher.NewGCM(decrypterCipher)
if err != nil {
return nil, err
}
return &aeadAESGCM{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadAESGCM) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadAESGCM) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
nonce := make([]byte, ivLen)
binary.BigEndian.PutUint64(nonce[ivLen-8:], uint64(packetNumber))
for i := 0; i < ivLen; i++ {
nonce[i] ^= iv[i]
}
return nonce
}
func (aead *aeadAESGCM) Overhead() int {
return aead.encrypter.Overhead()
}

View File

@@ -0,0 +1,48 @@
package crypto
import (
"fmt"
"hash/fnv"
"github.com/hashicorp/golang-lru"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
compressedCertsCache *lru.Cache
)
func getCompressedCert(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
// Hash all inputs
hasher := fnv.New64a()
for _, v := range chain {
hasher.Write(v)
}
hasher.Write(pCommonSetHashes)
hasher.Write(pCachedHashes)
hash := hasher.Sum64()
var result []byte
resultI, isCached := compressedCertsCache.Get(hash)
if isCached {
result = resultI.([]byte)
} else {
var err error
result, err = compressChain(chain, pCommonSetHashes, pCachedHashes)
if err != nil {
return nil, err
}
compressedCertsCache.Add(hash, result)
}
return result, nil
}
func init() {
var err error
compressedCertsCache, err = lru.New(protocol.NumCachedCertificates)
if err != nil {
panic(fmt.Sprintf("fatal error in quic-go: could not create lru cache: %s", err.Error()))
}
}

View File

@@ -0,0 +1,118 @@
package crypto
import (
"crypto/tls"
"errors"
"strings"
)
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof
type certChain struct {
config *tls.Config
}
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return signServerProof(cert, chlo, serverConfigData)
}
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return getCompressedCert(cert.Certificate, pCommonSetHashes, pCachedHashes)
}
// GetLeafCert gets the leaf certificate
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return cert.Certificate[0], nil
}
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
conf, err := maybeGetConfigForClient(c.config, sni)
if err != nil {
return nil, err
}
// The rest of this function is mostly copied from crypto/tls.getCertificate
if conf.GetCertificate != nil {
cert, err := conf.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if cert != nil || err != nil {
return cert, err
}
}
if len(conf.Certificates) == 0 {
return nil, errNoMatchingCertificate
}
if len(conf.Certificates) == 1 || conf.NameToCertificate == nil {
// There's only one choice, so no point doing any work.
return &conf.Certificates[0], nil
}
name := strings.ToLower(sni)
for len(name) > 0 && name[len(name)-1] == '.' {
name = name[:len(name)-1]
}
if cert, ok := conf.NameToCertificate[name]; ok {
return cert, nil
}
// try replacing labels in the name with wildcards until we get a
// match.
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok := conf.NameToCertificate[candidate]; ok {
return cert, nil
}
}
// If nothing matches, return the first certificate.
return &conf.Certificates[0], nil
}
func maybeGetConfigForClient(c *tls.Config, sni string) (*tls.Config, error) {
if c.GetConfigForClient == nil {
return c, nil
}
confForClient, err := c.GetConfigForClient(&tls.ClientHelloInfo{ServerName: sni})
if err != nil {
return nil, err
}
// if GetConfigForClient returns nil, use the original config
if confForClient == nil {
return c, nil
}
return confForClient, nil
}

View File

@@ -0,0 +1,272 @@
package crypto
import (
"bytes"
"compress/flate"
"compress/zlib"
"encoding/binary"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type entryType uint8
const (
entryCompressed entryType = 1
entryCached entryType = 2
entryCommon entryType = 3
)
type entry struct {
t entryType
h uint64 // set hash
i uint32 // index
}
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
res := &bytes.Buffer{}
cachedHashes, err := splitHashes(pCachedHashes)
if err != nil {
return nil, err
}
setHashes, err := splitHashes(pCommonSetHashes)
if err != nil {
return nil, err
}
chainHashes := make([]uint64, len(chain))
for i := range chain {
chainHashes[i] = HashCert(chain[i])
}
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
totalUncompressedLen := 0
for i, e := range entries {
res.WriteByte(uint8(e.t))
switch e.t {
case entryCached:
utils.LittleEndian.WriteUint64(res, e.h)
case entryCommon:
utils.LittleEndian.WriteUint64(res, e.h)
utils.LittleEndian.WriteUint32(res, e.i)
case entryCompressed:
totalUncompressedLen += 4 + len(chain[i])
}
}
res.WriteByte(0) // end of list
if totalUncompressedLen > 0 {
gz, err := zlib.NewWriterLevelDict(res, flate.BestCompression, buildZlibDictForEntries(entries, chain))
if err != nil {
return nil, fmt.Errorf("cert compression failed: %s", err.Error())
}
utils.LittleEndian.WriteUint32(res, uint32(totalUncompressedLen))
for i, e := range entries {
if e.t != entryCompressed {
continue
}
lenCert := len(chain[i])
gz.Write([]byte{
byte(lenCert & 0xff),
byte((lenCert >> 8) & 0xff),
byte((lenCert >> 16) & 0xff),
byte((lenCert >> 24) & 0xff),
})
gz.Write(chain[i])
}
gz.Close()
}
return res.Bytes(), nil
}
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.LittleEndian.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.LittleEndian.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.LittleEndian.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain))
chainLoop:
for i := range chain {
// Check if hash is in cachedHashes
for j := range cachedHashes {
if chainHashes[i] == cachedHashes[j] {
res[i] = entry{t: entryCached, h: chainHashes[i]}
continue chainLoop
}
}
// Go through common sets and check if it's in there
for _, setHash := range setHashes {
set, ok := certSets[setHash]
if !ok {
// We don't have this set
continue
}
// We have this set, check if chain[i] is in the set
pos := set.findCertInSet(chain[i])
if pos >= 0 {
// Found
res[i] = entry{t: entryCommon, h: setHash, i: uint32(pos)}
continue chainLoop
}
}
res[i] = entry{t: entryCompressed}
}
return res
}
func buildZlibDictForEntries(entries []entry, chain [][]byte) []byte {
var dict bytes.Buffer
// First the cached and common in reverse order
for i := len(entries) - 1; i >= 0; i-- {
if entries[i].t == entryCompressed {
continue
}
dict.Write(chain[i])
}
dict.Write(certDictZlib)
return dict.Bytes()
}
func splitHashes(hashes []byte) ([]uint64, error) {
if len(hashes)%8 != 0 {
return nil, errors.New("expected a multiple of 8 bytes for CCS / CCRT hashes")
}
n := len(hashes) / 8
res := make([]uint64, n)
for i := 0; i < n; i++ {
res[i] = binary.LittleEndian.Uint64(hashes[i*8 : (i+1)*8])
}
return res, nil
}
func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert)
return h.Sum64()
}

View File

@@ -0,0 +1,128 @@
package crypto
var certDictZlib = []byte{
0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04,
0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03,
0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x30,
0x5f, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x04, 0x01,
0x06, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86, 0xfd, 0x6d, 0x01, 0x07,
0x17, 0x01, 0x30, 0x33, 0x20, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65,
0x64, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x20, 0x53, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x34,
0x20, 0x53, 0x53, 0x4c, 0x20, 0x43, 0x41, 0x30, 0x1e, 0x17, 0x0d, 0x31,
0x32, 0x20, 0x53, 0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x20, 0x43, 0x41, 0x30, 0x2d, 0x61, 0x69, 0x61, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x2f, 0x45, 0x2d, 0x63, 0x72, 0x6c, 0x2e, 0x76, 0x65, 0x72, 0x69, 0x73,
0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x45, 0x2e, 0x63, 0x65,
0x72, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01,
0x01, 0x05, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x4a, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73,
0x2f, 0x63, 0x70, 0x73, 0x20, 0x28, 0x63, 0x29, 0x30, 0x30, 0x09, 0x06,
0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x1d, 0x30, 0x0d,
0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x7b, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x0e, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86,
0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01,
0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xd2,
0x6f, 0x64, 0x6f, 0x63, 0x61, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x43, 0x2e,
0x63, 0x72, 0x6c, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
0x04, 0x14, 0xb4, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69,
0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x30, 0x0b, 0x06, 0x03,
0x55, 0x1d, 0x0f, 0x04, 0x04, 0x03, 0x02, 0x01, 0x30, 0x0d, 0x06, 0x09,
0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x30,
0x81, 0xca, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13,
0x02, 0x55, 0x53, 0x31, 0x10, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x04, 0x08,
0x13, 0x07, 0x41, 0x72, 0x69, 0x7a, 0x6f, 0x6e, 0x61, 0x31, 0x13, 0x30,
0x11, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x0a, 0x53, 0x63, 0x6f, 0x74,
0x74, 0x73, 0x64, 0x61, 0x6c, 0x65, 0x31, 0x1a, 0x30, 0x18, 0x06, 0x03,
0x55, 0x04, 0x0a, 0x13, 0x11, 0x47, 0x6f, 0x44, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31, 0x33,
0x30, 0x31, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x2a, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66, 0x69, 0x63,
0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79,
0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74,
0x6f, 0x72, 0x79, 0x31, 0x30, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x04, 0x03,
0x13, 0x27, 0x47, 0x6f, 0x20, 0x44, 0x61, 0x64, 0x64, 0x79, 0x20, 0x53,
0x65, 0x63, 0x75, 0x72, 0x65, 0x20, 0x43, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x41, 0x75, 0x74, 0x68,
0x6f, 0x72, 0x69, 0x74, 0x79, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55,
0x04, 0x05, 0x13, 0x08, 0x30, 0x37, 0x39, 0x36, 0x39, 0x32, 0x38, 0x37,
0x30, 0x1e, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d,
0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x0c,
0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x02, 0x30, 0x00,
0x30, 0x1d, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff,
0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0x00, 0x30, 0x1d, 0x06, 0x03, 0x55,
0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05,
0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07,
0x03, 0x02, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff,
0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x33, 0x06, 0x03, 0x55, 0x1d,
0x1f, 0x04, 0x2c, 0x30, 0x2a, 0x30, 0x28, 0xa0, 0x26, 0xa0, 0x24, 0x86,
0x22, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e,
0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x67, 0x64, 0x73, 0x31, 0x2d, 0x32, 0x30, 0x2a, 0x30, 0x28, 0x06, 0x08,
0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x1c, 0x68, 0x74,
0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76, 0x65,
0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x63,
0x70, 0x73, 0x30, 0x34, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x5a, 0x17,
0x0d, 0x31, 0x33, 0x30, 0x35, 0x30, 0x39, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x73, 0x30, 0x39, 0x30, 0x37, 0x06, 0x08, 0x2b, 0x06, 0x01,
0x05, 0x05, 0x07, 0x02, 0x30, 0x44, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04,
0x3d, 0x30, 0x3b, 0x30, 0x39, 0x06, 0x0b, 0x60, 0x86, 0x48, 0x01, 0x86,
0xf8, 0x45, 0x01, 0x07, 0x17, 0x06, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
0x55, 0x04, 0x06, 0x13, 0x02, 0x47, 0x42, 0x31, 0x1b, 0x53, 0x31, 0x17,
0x30, 0x15, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x0e, 0x56, 0x65, 0x72,
0x69, 0x53, 0x69, 0x67, 0x6e, 0x2c, 0x20, 0x49, 0x6e, 0x63, 0x2e, 0x31,
0x1f, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x16, 0x56, 0x65,
0x72, 0x69, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x54, 0x72, 0x75, 0x73, 0x74,
0x20, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x31, 0x3b, 0x30, 0x39,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x32, 0x54, 0x65, 0x72, 0x6d, 0x73,
0x20, 0x6f, 0x66, 0x20, 0x75, 0x73, 0x65, 0x20, 0x61, 0x74, 0x20, 0x68,
0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x76,
0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x72, 0x70, 0x61, 0x20, 0x28, 0x63, 0x29, 0x30, 0x31, 0x10, 0x30, 0x0e,
0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x07, 0x53, 0x31, 0x13, 0x30, 0x11,
0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x0a, 0x47, 0x31, 0x13, 0x30, 0x11,
0x06, 0x0b, 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x3c, 0x02, 0x01,
0x03, 0x13, 0x02, 0x55, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04,
0x03, 0x14, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13,
0x31, 0x1d, 0x30, 0x1b, 0x06, 0x03, 0x55, 0x04, 0x0f, 0x13, 0x14, 0x50,
0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e,
0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x31, 0x12, 0x31, 0x21, 0x30,
0x1f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x18, 0x44, 0x6f, 0x6d, 0x61,
0x69, 0x6e, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x20, 0x56,
0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x64, 0x31, 0x14, 0x31, 0x31,
0x30, 0x2f, 0x06, 0x03, 0x55, 0x04, 0x0b, 0x13, 0x28, 0x53, 0x65, 0x65,
0x20, 0x77, 0x77, 0x77, 0x2e, 0x72, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63,
0x75, 0x72, 0x65, 0x2e, 0x67, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53,
0x69, 0x67, 0x6e, 0x31, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x41,
0x2e, 0x63, 0x72, 0x6c, 0x56, 0x65, 0x72, 0x69, 0x53, 0x69, 0x67, 0x6e,
0x20, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x20, 0x33, 0x20, 0x45, 0x63, 0x72,
0x6c, 0x2e, 0x67, 0x65, 0x6f, 0x74, 0x72, 0x75, 0x73, 0x74, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x63, 0x72, 0x6c, 0x73, 0x2f, 0x73, 0x64, 0x31, 0x1a,
0x30, 0x18, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x68, 0x74, 0x74, 0x70, 0x3a,
0x2f, 0x2f, 0x45, 0x56, 0x49, 0x6e, 0x74, 0x6c, 0x2d, 0x63, 0x63, 0x72,
0x74, 0x2e, 0x67, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x69, 0x63, 0x65, 0x72,
0x74, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x31, 0x6f, 0x63, 0x73, 0x70, 0x2e,
0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d,
0x30, 0x39, 0x72, 0x61, 0x70, 0x69, 0x64, 0x73, 0x73, 0x6c, 0x2e, 0x63,
0x6f, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64, 0x64, 0x79, 0x2e, 0x63,
0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72,
0x79, 0x2f, 0x30, 0x81, 0x80, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05,
0x07, 0x01, 0x01, 0x04, 0x74, 0x30, 0x72, 0x30, 0x24, 0x06, 0x08, 0x2b,
0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x18, 0x68, 0x74, 0x74,
0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x2e, 0x67, 0x6f, 0x64,
0x61, 0x64, 0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x30, 0x4a, 0x06,
0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x3e, 0x68,
0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x65, 0x72, 0x74, 0x69, 0x66,
0x69, 0x63, 0x61, 0x74, 0x65, 0x73, 0x2e, 0x67, 0x6f, 0x64, 0x61, 0x64,
0x64, 0x79, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73,
0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x67, 0x64, 0x5f, 0x69, 0x6e, 0x74,
0x65, 0x72, 0x6d, 0x65, 0x64, 0x69, 0x61, 0x74, 0x65, 0x2e, 0x63, 0x72,
0x74, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16,
0x80, 0x14, 0xfd, 0xac, 0x61, 0x32, 0x93, 0x6c, 0x45, 0xd6, 0xe2, 0xee,
0x85, 0x5f, 0x9a, 0xba, 0xe7, 0x76, 0x99, 0x68, 0xcc, 0xe7, 0x30, 0x27,
0x86, 0x29, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x86, 0x30,
0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73,
}

View File

@@ -0,0 +1,135 @@
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
GetChain() []*x509.Certificate
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetChain() []*x509.Certificate {
return c.chain
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
}
// we don't need to care about the tls.Config.ServerName here, since hostname has already been set to that value in the session setup
opts.DNSName = hostname
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}

View File

@@ -0,0 +1,24 @@
package crypto
import (
"bytes"
"github.com/lucas-clemente/quic-go-certificates"
)
type certSet [][]byte
var certSets = map[uint64]certSet{
certsets.CertSet2Hash: certsets.CertSet2,
certsets.CertSet3Hash: certsets.CertSet3,
}
// findCertInSet searches for the cert in the set. Negative return value means not found.
func (s *certSet) findCertInSet(cert []byte) int {
for i, c := range *s {
if bytes.Equal(c, cert) {
return i
}
}
return -1
}

View File

@@ -0,0 +1,61 @@
// +build ignore
package crypto
import (
"crypto/cipher"
"encoding/binary"
"errors"
"github.com/aead/chacha20"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
type aeadChacha20Poly1305 struct {
otherIV []byte
myIV []byte
encrypter cipher.AEAD
decrypter cipher.AEAD
}
// NewAEADChacha20Poly1305 creates a AEAD using chacha20poly1305
func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV []byte) (AEAD, error) {
if len(myKey) != 32 || len(otherKey) != 32 || len(myIV) != 4 || len(otherIV) != 4 {
return nil, errors.New("chacha20poly1305: expected 32-byte keys and 4-byte IVs")
}
// copy because ChaCha20Poly1305 expects array pointers
var MyKey, OtherKey [32]byte
copy(MyKey[:], myKey)
copy(OtherKey[:], otherKey)
encrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&MyKey, 12)
if err != nil {
return nil, err
}
decrypter, err := chacha20.NewChaCha20Poly1305WithTagSize(&OtherKey, 12)
if err != nil {
return nil, err
}
return &aeadChacha20Poly1305{
otherIV: otherIV,
myIV: myIV,
encrypter: encrypter,
decrypter: decrypter,
}, nil
}
func (aead *aeadChacha20Poly1305) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return aead.decrypter.Open(dst, aead.makeNonce(aead.otherIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
return aead.encrypter.Seal(dst, aead.makeNonce(aead.myIV, packetNumber), src, associatedData)
}
func (aead *aeadChacha20Poly1305) makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte {
res := make([]byte, 12)
copy(res[0:4], iv)
binary.LittleEndian.PutUint64(res[4:12], uint64(packetNumber))
return res
}

View File

@@ -0,0 +1,41 @@
package crypto
import (
"crypto/rand"
"errors"
"golang.org/x/crypto/curve25519"
)
// KeyExchange manages the exchange of keys
type curve25519KEX struct {
secret [32]byte
public [32]byte
}
var _ KeyExchange = &curve25519KEX{}
// NewCurve25519KEX creates a new KeyExchange using Curve25519, see https://cr.yp.to/ecdh.html
func NewCurve25519KEX() (KeyExchange, error) {
c := &curve25519KEX{}
if _, err := rand.Read(c.secret[:]); err != nil {
return nil, errors.New("Curve25519: could not create private key")
}
curve25519.ScalarBaseMult(&c.public, &c.secret)
return c, nil
}
func (c *curve25519KEX) PublicKey() []byte {
return c.public[:]
}
func (c *curve25519KEX) CalculateSharedKey(otherPublic []byte) ([]byte, error) {
if len(otherPublic) != 32 {
return nil, errors.New("Curve25519: expected public key of 32 byte")
}
var res [32]byte
var otherPublicArray [32]byte
copy(otherPublicArray[:], otherPublic)
curve25519.ScalarMult(&res, &c.secret, &otherPublicArray)
return res[:], nil
}

View File

@@ -0,0 +1,56 @@
package crypto
import (
"crypto"
"crypto/hmac"
"encoding/binary"
)
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
if salt == nil {
salt = make([]byte, hash.Size())
}
if secret == nil {
secret = make([]byte, hash.Size())
}
extractor := hmac.New(hash.New, salt)
extractor.Write(secret)
return extractor.Sum(nil)
}
// copied from https://github.com/cloudflare/tls-tris/blob/master/hkdf.go
func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
var (
expander = hmac.New(hash.New, prk)
res = make([]byte, l)
counter = byte(1)
prev []byte
)
if l > 255*expander.Size() {
panic("hkdf: requested too much output")
}
p := res
for len(p) > 0 {
expander.Reset()
expander.Write(prev)
expander.Write(info)
expander.Write([]byte{counter})
prev = expander.Sum(prev[:0])
counter++
n := copy(p, prev)
p = p[n:]
}
return res
}
func qhkdfExpand(secret []byte, label string, length int) []byte {
qlabel := make([]byte, 2+1+5+len(label))
binary.BigEndian.PutUint16(qlabel[0:2], uint16(length))
qlabel[2] = uint8(5 + len(label))
copy(qlabel[3:], []byte("QUIC "+label))
return hkdfExpand(crypto.SHA256, secret, qlabel, length)
}

View File

@@ -0,0 +1,49 @@
package crypto
import (
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
clientExporterLabel = "EXPORTER-QUIC client 1rtt"
serverExporterLabel = "EXPORTER-QUIC server 1rtt"
)
// A TLSExporter gets the negotiated ciphersuite and computes exporter
type TLSExporter interface {
ConnectionState() mint.ConnectionState
ComputeExporter(label string, context []byte, keyLength int) ([]byte, error)
}
// DeriveAESKeys derives the AES keys and creates a matching AES-GCM AEAD instance
func DeriveAESKeys(tls TLSExporter, pers protocol.Perspective) (AEAD, error) {
var myLabel, otherLabel string
if pers == protocol.PerspectiveClient {
myLabel = clientExporterLabel
otherLabel = serverExporterLabel
} else {
myLabel = serverExporterLabel
otherLabel = clientExporterLabel
}
myKey, myIV, err := computeKeyAndIV(tls, myLabel)
if err != nil {
return nil, err
}
otherKey, otherIV, err := computeKeyAndIV(tls, otherLabel)
if err != nil {
return nil, err
}
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeKeyAndIV(tls TLSExporter, label string) (key, iv []byte, err error) {
cs := tls.ConnectionState().CipherSuite
secret, err := tls.ComputeExporter(label, nil, cs.Hash.Size())
if err != nil {
return nil, nil, err
}
key = qhkdfExpand(secret, "key", cs.KeyLen)
iv = qhkdfExpand(secret, "iv", cs.IvLen)
return key, iv, nil
}

View File

@@ -0,0 +1,100 @@
package crypto
import (
"bytes"
"crypto/sha256"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"golang.org/x/crypto/hkdf"
)
// DeriveKeysChacha20 derives the client and server keys and creates a matching chacha20poly1305 AEAD instance
// func DeriveKeysChacha20(version protocol.VersionNumber, forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
// otherKey, myKey, otherIV, myIV, err := deriveKeys(version, forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 32)
// if err != nil {
// return nil, err
// }
// return NewAEADChacha20Poly1305(otherKey, myKey, otherIV, myIV)
// }
// DeriveQuicCryptoAESKeys derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveQuicCryptoAESKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
var swap bool
if pers == protocol.PerspectiveClient {
swap = true
}
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
if err != nil {
return nil, err
}
return NewAEADAESGCM12(otherKey, myKey, otherIV, myIV)
}
// deriveKeys derives the keys and the IVs
// swap should be set true if generating the values for the client, and false for the server
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
var info bytes.Buffer
if forwardSecure {
info.Write([]byte("QUIC forward secure key expansion\x00"))
} else {
info.Write([]byte("QUIC key expansion\x00"))
}
info.Write(connID)
info.Write(chlo)
info.Write(scfg)
info.Write(cert)
r := hkdf.New(sha256.New, sharedSecret, nonces, info.Bytes())
s := make([]byte, 2*keyLen+2*4)
if _, err := io.ReadFull(r, s); err != nil {
return nil, nil, nil, nil, err
}
key1 := s[:keyLen]
key2 := s[keyLen : 2*keyLen]
iv1 := s[2*keyLen : 2*keyLen+4]
iv2 := s[2*keyLen+4:]
var otherKey, myKey []byte
var otherIV, myIV []byte
if !forwardSecure {
if err := diversify(key2, iv2, divNonce); err != nil {
return nil, nil, nil, nil, err
}
}
if swap {
otherKey = key2
myKey = key1
otherIV = iv2
myIV = iv1
} else {
otherKey = key1
myKey = key2
otherIV = iv1
myIV = iv2
}
return otherKey, myKey, otherIV, myIV, nil
}
func diversify(key, iv, divNonce []byte) error {
secret := make([]byte, len(key)+len(iv))
copy(secret, key)
copy(secret[len(key):], iv)
r := hkdf.New(sha256.New, secret, divNonce, []byte("QUIC key diversification"))
if _, err := io.ReadFull(r, key); err != nil {
return err
}
if _, err := io.ReadFull(r, iv); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,7 @@
package crypto
// KeyExchange manages the exchange of keys
type KeyExchange interface {
PublicKey() []byte
CalculateSharedKey(otherPublic []byte) ([]byte, error)
}

View File

@@ -0,0 +1,11 @@
package crypto
import "github.com/lucas-clemente/quic-go/internal/protocol"
// NewNullAEAD creates a NullAEAD
func NewNullAEAD(p protocol.Perspective, connID protocol.ConnectionID, v protocol.VersionNumber) (AEAD, error) {
if v.UsesTLS() {
return newNullAEADAESGCM(connID, p)
}
return &nullAEADFNV128a{perspective: p}, nil
}

View File

@@ -0,0 +1,40 @@
package crypto
import (
"crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var quicVersion1Salt = []byte{0x9c, 0x10, 0x8f, 0x98, 0x52, 0x0a, 0x5c, 0x5c, 0x32, 0x96, 0x8e, 0x95, 0x0e, 0x8a, 0x2c, 0x5f, 0xe0, 0x6d, 0x6c, 0x38}
func newNullAEADAESGCM(connectionID protocol.ConnectionID, pers protocol.Perspective) (AEAD, error) {
clientSecret, serverSecret := computeSecrets(connectionID)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myIV := computeNullAEADKeyAndIV(mySecret)
otherKey, otherIV := computeNullAEADKeyAndIV(otherSecret)
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func computeSecrets(connID protocol.ConnectionID) (clientSecret, serverSecret []byte) {
handshakeSecret := hkdfExtract(crypto.SHA256, connID, quicVersion1Salt)
clientSecret = qhkdfExpand(handshakeSecret, "client hs", crypto.SHA256.Size())
serverSecret = qhkdfExpand(handshakeSecret, "server hs", crypto.SHA256.Size())
return
}
func computeNullAEADKeyAndIV(secret []byte) (key, iv []byte) {
key = qhkdfExpand(secret, "key", 16)
iv = qhkdfExpand(secret, "iv", 12)
return
}

View File

@@ -0,0 +1,79 @@
package crypto
import (
"bytes"
"errors"
"fmt"
"hash/fnv"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// nullAEAD handles not-yet encrypted packets
type nullAEADFNV128a struct {
perspective protocol.Perspective
}
var _ AEAD = &nullAEADFNV128a{}
// Open and verify the ciphertext
func (n *nullAEADFNV128a) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
if len(src) < 12 {
return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long")
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src[12:])
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Client"))
} else {
hash.Write([]byte("Server"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
if !bytes.Equal(sum[:12], src[:12]) {
return nil, fmt.Errorf("NullAEAD: failed to authenticate received data (%#v vs %#v)", sum[:12], src[:12])
}
return src[12:], nil
}
// Seal writes hash and ciphertext to the buffer
func (n *nullAEADFNV128a) Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte {
if cap(dst) < 12+len(src) {
dst = make([]byte, 12+len(src))
} else {
dst = dst[:12+len(src)]
}
hash := fnv.New128a()
hash.Write(associatedData)
hash.Write(src)
if n.perspective == protocol.PerspectiveServer {
hash.Write([]byte("Server"))
} else {
hash.Write([]byte("Client"))
}
sum := make([]byte, 0, 16)
sum = hash.Sum(sum)
// The tag is written in little endian, so we need to reverse the slice.
reverse(sum)
copy(dst[12:], src)
copy(dst, sum[:12])
return dst
}
func (n *nullAEADFNV128a) Overhead() int {
return 12
}
func reverse(a []byte) {
for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 {
a[left], a[right] = a[right], a[left]
}
}

View File

@@ -0,0 +1,66 @@
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// signServerProof signs CHLO and server config for use in the server proof
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
}
// verifyServerProof verifies the server proof signature
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
// RSA
if cert.PublicKeyAlgorithm == x509.RSA {
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
return err == nil
}
// ECDSA
signature := &ecdsaSignature{}
rest, err := asn1.Unmarshal(proof, signature)
if err != nil || len(rest) != 0 {
return false
}
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
}

View File

@@ -0,0 +1,122 @@
package flowcontrol
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastBlockedAt protocol.ByteCount
// for receiving data
mutex sync.RWMutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *congestion.RTTStats
logger utils.Logger
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
if offset > c.sendWindow {
c.sendWindow = offset
}
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return c.sendWindow - c.bytesSent
}
func (c *baseFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
defer c.mutex.Unlock()
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.startNewAutoTuningEpoch()
}
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount((float64(c.receiveWindowSize) * float64((1 - protocol.WindowUpdateThreshold))))
}
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
if time.Since(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
}
c.startNewAutoTuningEpoch()
}
func (c *baseFlowController) startNewAutoTuningEpoch() {
c.epochStartTime = time.Now()
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View File

@@ -0,0 +1,87 @@
package flowcontrol
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
// NewConnectionFlowController gets a new flow controller for the connection
// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0.
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
rttStats *congestion.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.highestReceived += increment
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow))
}
return nil
}
func (c *connectionFlowController) MaybeQueueWindowUpdate() {
c.mutex.Lock()
hasWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if hasWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.startNewAutoTuningEpoch()
}
c.mutex.Unlock()
}

View File

@@ -0,0 +1,38 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
MaybeQueueWindowUpdate() // queues a window update, if necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream, as contained in a STREAM frame with FIN bit, and the RST_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
}

View File

@@ -0,0 +1,149 @@
package flowcontrol
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type streamFlowController struct {
baseFlowController
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
contributesToConnection bool // does the stream contribute to connection level flow control
receivedFinalOffset bool
}
var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
contributesToConnection bool,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *congestion.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
contributesToConnection: contributesToConnection,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
logger: logger,
},
}
}
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
func (c *streamFlowController) UpdateHighestReceived(byteOffset protocol.ByteCount, final bool) error {
c.mutex.Lock()
defer c.mutex.Unlock()
// when receiving a final offset, check that this final offset is consistent with a final offset we might have received earlier
if final && c.receivedFinalOffset && byteOffset != c.highestReceived {
return qerr.Error(qerr.StreamDataAfterTermination, fmt.Sprintf("Received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, byteOffset))
}
// if we already received a final offset, check that the offset in the STREAM frames is below the final offset
if c.receivedFinalOffset && byteOffset > c.highestReceived {
return qerr.StreamDataAfterTermination
}
if final {
c.receivedFinalOffset = true
}
if byteOffset == c.highestReceived {
return nil
}
if byteOffset <= c.highestReceived {
// a STREAM_FRAME with a higher offset was received before.
if final {
// If the current byteOffset is smaller than the offset in that STREAM_FRAME, this STREAM_FRAME contained data after the end of the stream
return qerr.StreamDataAfterTermination
}
// this is a reordered STREAM_FRAME
return nil
}
increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset
if c.checkFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, c.streamID, c.receiveWindow))
}
if c.contributesToConnection {
return c.connection.IncrementHighestReceived(increment)
}
return nil
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.baseFlowController.AddBytesRead(n)
if c.contributesToConnection {
c.connection.AddBytesRead(n)
}
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
if c.contributesToConnection {
c.connection.AddBytesSent(n)
}
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
window := c.baseFlowController.sendWindowSize()
if c.contributesToConnection {
window = utils.MinByteCount(window, c.connection.SendWindowSize())
}
return window
}
func (c *streamFlowController) MaybeQueueWindowUpdate() {
c.mutex.Lock()
hasWindowUpdate := !c.receivedFinalOffset && c.hasWindowUpdate()
c.mutex.Unlock()
if hasWindowUpdate {
c.queueWindowUpdate()
}
if c.contributesToConnection {
c.connection.MaybeQueueWindowUpdate()
}
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
// if we already received the final offset for this stream, the peer won't need any additional flow control credit
if c.receivedFinalOffset {
c.mutex.Unlock()
return 0
}
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
if c.contributesToConnection {
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
}
c.mutex.Unlock()
return offset
}

View File

@@ -0,0 +1,99 @@
package handshake
import (
"encoding/asn1"
"fmt"
"net"
"time"
)
const (
cookiePrefixIP byte = iota
cookiePrefixString
)
// A Cookie is derived from the client address and can be used to verify the ownership of this address.
type Cookie struct {
RemoteAddr string
// The time that the STK was issued (resolution 1 second)
SentTime time.Time
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
Data []byte
Timestamp int64
}
// A CookieGenerator generates Cookies
type CookieGenerator struct {
cookieProtector cookieProtector
}
// NewCookieGenerator initializes a new CookieGenerator
func NewCookieGenerator() (*CookieGenerator, error) {
cookieProtector, err := newCookieProtector()
if err != nil {
return nil, err
}
return &CookieGenerator{
cookieProtector: cookieProtector,
}, nil
}
// NewToken generates a new Cookie for a given source address
func (g *CookieGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
Data: encodeRemoteAddr(raddr),
Timestamp: time.Now().Unix(),
})
if err != nil {
return nil, err
}
return g.cookieProtector.NewToken(data)
}
// DecodeToken decodes a Cookie
func (g *CookieGenerator) DecodeToken(encrypted []byte) (*Cookie, error) {
// if the client didn't send any Cookie, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.cookieProtector.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
return &Cookie{
RemoteAddr: decodeRemoteAddr(t.Data),
SentTime: time.Unix(t.Timestamp, 0),
}, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the Cookie
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{cookiePrefixIP}, udpAddr.IP...)
}
return append([]byte{cookiePrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the Cookie
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a Cookie that we generated. Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == cookiePrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View File

@@ -0,0 +1,86 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// CookieProtector is used to create and verify a cookie
type cookieProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const (
cookieSecretSize = 32
cookieNonceSize = 32
)
// cookieProtector is used to create and verify a cookie
type cookieProtectorImpl struct {
secret []byte
}
// newCookieProtector creates a source for source address tokens
func newCookieProtector() (cookieProtector, error) {
secret := make([]byte, cookieSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &cookieProtectorImpl{secret: secret}, nil
}
// NewToken encodes data into a new token.
func (s *cookieProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, cookieNonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *cookieProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
if len(p) < cookieNonceSize {
return nil, fmt.Errorf("Token too short: %d", len(p))
}
nonce := p[:cookieNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[cookieNonceSize:], nil)
}
func (s *cookieProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go cookie source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View File

@@ -0,0 +1,542 @@
package handshake
import (
"bytes"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type cryptoSetupClient struct {
mutex sync.RWMutex
hostname string
connID protocol.ConnectionID
version protocol.VersionNumber
initialVersion protocol.VersionNumber
negotiatedVersions []protocol.VersionNumber
cryptoStream io.ReadWriter
serverConfig *serverConfigClient
stk []byte
sno []byte
nonc []byte
proof []byte
chloForSignature []byte
lastSentCHLO []byte
certManager crypto.CertManager
divNonceChan chan struct{}
diversificationNonce []byte
clientHelloCounter int
serverVerified bool // has the certificate chain and the proof already been verified
keyDerivation QuicCryptoKeyDerivationFunction
receivedSecurePacket bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
params *TransportParameters
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupClient{}
var (
errNoObitForClientNonce = errors.New("CryptoSetup BUG: No OBIT for client nonce available")
errClientNonceAlreadyExists = errors.New("CryptoSetup BUG: A client nonce was already generated")
errConflictingDiversificationNonces = errors.New("Received two different diversification nonces")
)
// NewCryptoSetupClient creates a new CryptoSetup instance for a client
func NewCryptoSetupClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
version protocol.VersionNumber,
tlsConf *tls.Config,
params *TransportParameters,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
initialVersion protocol.VersionNumber,
negotiatedVersions []protocol.VersionNumber,
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
divNonceChan := make(chan struct{})
cs := &cryptoSetupClient{
cryptoStream: cryptoStream,
hostname: tlsConf.ServerName,
connID: connID,
version: version,
certManager: crypto.NewCertManager(tlsConf),
params: params,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
nullAEAD: nullAEAD,
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
initialVersion: initialVersion,
// The server might have sent greased versions in the Version Negotiation packet.
// We need strip those from the list, since they won't be included in the handshake tag.
negotiatedVersions: protocol.StripGreasedVersions(negotiatedVersions),
divNonceChan: divNonceChan,
logger: logger,
}
return cs, nil
}
func (h *cryptoSetupClient) HandleCryptoStream() error {
messageChan := make(chan HandshakeMessage)
errorChan := make(chan error, 1)
go func() {
for {
message, err := ParseHandshakeMessage(h.cryptoStream)
if err != nil {
errorChan <- qerr.Error(qerr.HandshakeFailed, err.Error())
return
}
messageChan <- message
}
}()
for {
if err := h.maybeUpgradeCrypto(); err != nil {
return err
}
h.mutex.RLock()
sendCHLO := h.secureAEAD == nil
h.mutex.RUnlock()
if sendCHLO {
if err := h.sendCHLO(); err != nil {
return err
}
}
var message HandshakeMessage
select {
case <-h.divNonceChan:
// there's no message to process, but we should try upgrading the crypto again
continue
case message = <-messageChan:
case err := <-errorChan:
return err
}
h.logger.Debugf("Got %s", message)
switch message.Tag {
case TagREJ:
if err := h.handleREJMessage(message.Data); err != nil {
return err
}
case TagSHLO:
params, err := h.handleSHLOMessage(message.Data)
if err != nil {
return err
}
// blocks until the session has received the parameters
h.paramsChan <- *params
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
default:
return qerr.InvalidCryptoMessageType
}
}
}
func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error {
var err error
if stk, ok := cryptoData[TagSTK]; ok {
h.stk = stk
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
// TODO: what happens if the server sends a different server config in two packets?
if scfg, ok := cryptoData[TagSCFG]; ok {
h.serverConfig, err = parseServerConfig(scfg)
if err != nil {
return err
}
if h.serverConfig.IsExpired() {
return qerr.CryptoServerConfigExpired
}
// now that we have a server config, we can use its OBIT value to generate a client nonce
if len(h.nonc) == 0 {
err = h.generateClientNonce()
if err != nil {
return err
}
}
}
if proof, ok := cryptoData[TagPROF]; ok {
h.proof = proof
h.chloForSignature = h.lastSentCHLO
}
if crt, ok := cryptoData[TagCERT]; ok {
err := h.certManager.SetData(crt)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
err = h.certManager.Verify(h.hostname)
if err != nil {
h.logger.Infof("Certificate validation failed: %s", err.Error())
return qerr.ProofInvalid
}
}
if h.serverConfig != nil && len(h.proof) != 0 && h.certManager.GetLeafCert() != nil {
validProof := h.certManager.VerifyServerProof(h.proof, h.chloForSignature, h.serverConfig.Get())
if !validProof {
h.logger.Infof("Server proof verification failed")
return qerr.ProofInvalid
}
h.serverVerified = true
}
return nil
}
func (h *cryptoSetupClient) handleSHLOMessage(cryptoData map[Tag][]byte) (*TransportParameters, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.receivedSecurePacket {
return nil, qerr.Error(qerr.CryptoEncryptionLevelIncorrect, "unencrypted SHLO message")
}
if sno, ok := cryptoData[TagSNO]; ok {
h.sno = sno
}
serverPubs, ok := cryptoData[TagPUBS]
if !ok {
return nil, qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
verTag, ok := cryptoData[TagVER]
if !ok {
return nil, qerr.Error(qerr.InvalidCryptoMessageParameter, "server hello missing version list")
}
if !h.validateVersionList(verTag) {
return nil, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
nonce := append(h.nonc, h.sno...)
ephermalSharedSecret, err := h.serverConfig.kex.CalculateSharedKey(serverPubs)
if err != nil {
return nil, err
}
leafCert := h.certManager.GetLeafCert()
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
nil,
protocol.PerspectiveClient,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption. Stopping to accept all lower encryption levels.")
params, err := readHelloMap(cryptoData)
if err != nil {
return nil, qerr.InvalidCryptoMessageParameter
}
return params, nil
}
func (h *cryptoSetupClient) validateVersionList(verTags []byte) bool {
numNegotiatedVersions := len(h.negotiatedVersions)
if numNegotiatedVersions == 0 {
return true
}
if len(verTags)%4 != 0 || len(verTags)/4 != numNegotiatedVersions {
return false
}
b := bytes.NewReader(verTags)
for i := 0; i < numNegotiatedVersions; i++ {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil { // should never occur, since the length was already checked
return false
}
if protocol.VersionNumber(v) != h.negotiatedVersions[i] {
return false
}
}
return true
}
func (h *cryptoSetupClient) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
data, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
return data, protocol.EncryptionForwardSecure, nil
}
return nil, protocol.EncryptionUnspecified, err
}
if h.secureAEAD != nil {
data, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return data, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return nil, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, nil
}
func (h *cryptoSetupClient) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
} else if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
} else {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
}
func (h *cryptoSetupClient) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupClient) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupClient: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupClient: no encryption level specified")
}
func (h *cryptoSetupClient) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
HandshakeComplete: h.forwardSecureAEAD != nil,
PeerCertificates: h.certManager.GetChain(),
}
}
func (h *cryptoSetupClient) SetDiversificationNonce(divNonce []byte) error {
h.mutex.Lock()
if len(h.diversificationNonce) > 0 {
defer h.mutex.Unlock()
if !bytes.Equal(h.diversificationNonce, divNonce) {
return errConflictingDiversificationNonces
}
return nil
}
h.diversificationNonce = divNonce
h.mutex.Unlock()
h.divNonceChan <- struct{}{}
return nil
}
func (h *cryptoSetupClient) sendCHLO() error {
h.clientHelloCounter++
if h.clientHelloCounter > protocol.MaxClientHellos {
return qerr.Error(qerr.CryptoTooManyRejects, fmt.Sprintf("More than %d rejects", protocol.MaxClientHellos))
}
b := &bytes.Buffer{}
tags, err := h.getTags()
if err != nil {
return err
}
h.addPadding(tags)
message := HandshakeMessage{
Tag: TagCHLO,
Data: tags,
}
h.logger.Debugf("Sending %s", message)
message.Write(b)
_, err = h.cryptoStream.Write(b.Bytes())
if err != nil {
return err
}
h.lastSentCHLO = b.Bytes()
return nil
}
func (h *cryptoSetupClient) getTags() (map[Tag][]byte, error) {
tags := h.params.getHelloMap()
tags[TagSNI] = []byte(h.hostname)
tags[TagPDMD] = []byte("X509")
ccs := h.certManager.GetCommonCertificateHashes()
if len(ccs) > 0 {
tags[TagCCS] = ccs
}
versionTag := make([]byte, 4)
binary.BigEndian.PutUint32(versionTag, uint32(h.initialVersion))
tags[TagVER] = versionTag
if len(h.stk) > 0 {
tags[TagSTK] = h.stk
}
if len(h.sno) > 0 {
tags[TagSNO] = h.sno
}
if h.serverConfig != nil {
tags[TagSCID] = h.serverConfig.ID
leafCert := h.certManager.GetLeafCert()
if leafCert != nil {
certHash, _ := h.certManager.GetLeafCertHash()
xlct := make([]byte, 8)
binary.LittleEndian.PutUint64(xlct, certHash)
tags[TagNONC] = h.nonc
tags[TagXLCT] = xlct
tags[TagKEXS] = []byte("C255")
tags[TagAEAD] = []byte("AESG")
tags[TagPUBS] = h.serverConfig.kex.PublicKey() // TODO: check if 3 bytes need to be prepended
}
}
return tags, nil
}
// add a TagPAD to a tagMap, such that the total size will be bigger than the ClientHelloMinimumSize
func (h *cryptoSetupClient) addPadding(tags map[Tag][]byte) {
var size int
for _, tag := range tags {
size += 8 + len(tag) // 4 bytes for the tag + 4 bytes for the offset + the length of the data
}
paddingSize := protocol.MinClientHelloSize - size
if paddingSize > 0 {
tags[TagPAD] = bytes.Repeat([]byte{0}, paddingSize)
}
}
func (h *cryptoSetupClient) maybeUpgradeCrypto() error {
if !h.serverVerified {
return nil
}
h.mutex.Lock()
defer h.mutex.Unlock()
leafCert := h.certManager.GetLeafCert()
if h.secureAEAD == nil && (h.serverConfig != nil && len(h.serverConfig.sharedSecret) > 0 && len(h.nonc) > 0 && len(leafCert) > 0 && len(h.diversificationNonce) > 0 && len(h.lastSentCHLO) > 0) {
var err error
var nonce []byte
if h.sno == nil {
nonce = h.nonc
} else {
nonce = append(h.nonc, h.sno...)
}
h.secureAEAD, err = h.keyDerivation(
false,
h.serverConfig.sharedSecret,
nonce,
h.connID,
h.lastSentCHLO,
h.serverConfig.Get(),
leafCert,
h.diversificationNonce,
protocol.PerspectiveClient,
)
if err != nil {
return err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
}
return nil
}
func (h *cryptoSetupClient) generateClientNonce() error {
if len(h.nonc) > 0 {
return errClientNonceAlreadyExists
}
nonc := make([]byte, 32)
binary.BigEndian.PutUint32(nonc, uint32(time.Now().Unix()))
if len(h.serverConfig.obit) != 8 {
return errNoObitForClientNonce
}
copy(nonc[4:12], h.serverConfig.obit)
_, err := rand.Read(nonc[12:])
if err != nil {
return err
}
h.nonc = nonc
return nil
}

View File

@@ -0,0 +1,467 @@
package handshake
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// QuicCryptoKeyDerivationFunction is used for key derivation
type QuicCryptoKeyDerivationFunction func(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (crypto.AEAD, error)
// KeyExchangeFunction is used to make a new KEX
type KeyExchangeFunction func() (crypto.KeyExchange, error)
// The CryptoSetupServer handles all things crypto for the Session
type cryptoSetupServer struct {
mutex sync.RWMutex
connID protocol.ConnectionID
remoteAddr net.Addr
scfg *ServerConfig
diversificationNonce []byte
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
acceptSTKCallback func(net.Addr, *Cookie) bool
nullAEAD crypto.AEAD
secureAEAD crypto.AEAD
forwardSecureAEAD crypto.AEAD
receivedForwardSecurePacket bool
receivedSecurePacket bool
sentSHLO chan struct{} // this channel is closed as soon as the SHLO has been written
receivedParams bool
paramsChan chan<- TransportParameters
handshakeEvent chan<- struct{}
keyDerivation QuicCryptoKeyDerivationFunction
keyExchange KeyExchangeFunction
cryptoStream io.ReadWriter
params *TransportParameters
sni string // need to fill out the ConnectionState
logger utils.Logger
}
var _ CryptoSetup = &cryptoSetupServer{}
// ErrNSTPExperiment is returned when the client sends the NSTP tag in the CHLO.
// This is an experiment implemented by Chrome in QUIC 38, which we don't support at this point.
var ErrNSTPExperiment = qerr.Error(qerr.InvalidCryptoMessageParameter, "NSTP experiment. Unsupported")
// NewCryptoSetup creates a new CryptoSetup instance for a server
func NewCryptoSetup(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
remoteAddr net.Addr,
version protocol.VersionNumber,
divNonce []byte,
scfg *ServerConfig,
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
acceptSTK func(net.Addr, *Cookie) bool,
paramsChan chan<- TransportParameters,
handshakeEvent chan<- struct{},
logger utils.Logger,
) (CryptoSetup, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
return &cryptoSetupServer{
cryptoStream: cryptoStream,
connID: connID,
remoteAddr: remoteAddr,
version: version,
supportedVersions: supportedVersions,
diversificationNonce: divNonce,
scfg: scfg,
keyDerivation: crypto.DeriveQuicCryptoAESKeys,
keyExchange: getEphermalKEX,
nullAEAD: nullAEAD,
params: params,
acceptSTKCallback: acceptSTK,
sentSHLO: make(chan struct{}),
paramsChan: paramsChan,
handshakeEvent: handshakeEvent,
logger: logger,
}, nil
}
// HandleCryptoStream reads and writes messages on the crypto stream
func (h *cryptoSetupServer) HandleCryptoStream() error {
for {
var chloData bytes.Buffer
message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
if err != nil {
return qerr.HandshakeFailed
}
if message.Tag != TagCHLO {
return qerr.InvalidCryptoMessageType
}
h.logger.Debugf("Got %s", message)
done, err := h.handleMessage(chloData.Bytes(), message.Data)
if err != nil {
return err
}
if done {
return nil
}
}
}
func (h *cryptoSetupServer) handleMessage(chloData []byte, cryptoData map[Tag][]byte) (bool, error) {
if _, isNSTPExperiment := cryptoData[TagNSTP]; isNSTPExperiment {
return false, ErrNSTPExperiment
}
sniSlice, ok := cryptoData[TagSNI]
if !ok {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
sni := string(sniSlice)
if sni == "" {
return false, qerr.Error(qerr.CryptoMessageParameterNotFound, "SNI required")
}
h.sni = sni
// prevent version downgrade attacks
// see https://groups.google.com/a/chromium.org/forum/#!topic/proto-quic/N-de9j63tCk for a discussion and examples
verSlice, ok := cryptoData[TagVER]
if !ok {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag")
}
if len(verSlice) != 4 {
return false, qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag")
}
ver := protocol.VersionNumber(binary.BigEndian.Uint32(verSlice))
// If the client's preferred version is not the version we are currently speaking, then the client went through a version negotiation. In this case, we need to make sure that we actually do not support this version and that it wasn't a downgrade attack.
if ver != h.version && protocol.IsSupportedVersion(h.supportedVersions, ver) {
return false, qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected")
}
var reply []byte
var err error
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return false, err
}
params, err := readHelloMap(cryptoData)
if err != nil {
return false, err
}
// blocks until the session has received the parameters
if !h.receivedParams {
h.receivedParams = true
h.paramsChan <- *params
}
if !h.isInchoateCHLO(cryptoData, certUncompressed) {
// We have a CHLO with a proper server config ID, do a 0-RTT handshake
reply, err = h.handleCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
if _, err := h.cryptoStream.Write(reply); err != nil {
return false, err
}
h.handshakeEvent <- struct{}{}
close(h.sentSHLO)
return true, nil
}
// We have an inchoate or non-matching CHLO, we now send a rejection
reply, err = h.handleInchoateCHLO(sni, chloData, cryptoData)
if err != nil {
return false, err
}
_, err = h.cryptoStream.Write(reply)
return false, err
}
// Open a message
func (h *cryptoSetupServer) Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
res, err := h.forwardSecureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
if !h.receivedForwardSecurePacket { // this is the first forward secure packet we receive from the client
h.logger.Debugf("Received first forward-secure packet. Stopping to accept all lower encryption levels.")
h.receivedForwardSecurePacket = true
// wait for the send on the handshakeEvent chan
<-h.sentSHLO
close(h.handshakeEvent)
}
return res, protocol.EncryptionForwardSecure, nil
}
if h.receivedForwardSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
if h.secureAEAD != nil {
res, err := h.secureAEAD.Open(dst, src, packetNumber, associatedData)
if err == nil {
h.logger.Debugf("Received first secure packet. Stopping to accept unencrypted packets.")
h.receivedSecurePacket = true
return res, protocol.EncryptionSecure, nil
}
if h.receivedSecurePacket {
return nil, protocol.EncryptionUnspecified, err
}
}
res, err := h.nullAEAD.Open(dst, src, packetNumber, associatedData)
if err != nil {
return res, protocol.EncryptionUnspecified, err
}
return res, protocol.EncryptionUnencrypted, err
}
func (h *cryptoSetupServer) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.forwardSecureAEAD != nil {
return protocol.EncryptionForwardSecure, h.forwardSecureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.secureAEAD != nil {
return protocol.EncryptionSecure, h.secureAEAD
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupServer) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionSecure:
if h.secureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no secureAEAD")
}
return h.secureAEAD, nil
case protocol.EncryptionForwardSecure:
if h.forwardSecureAEAD == nil {
return nil, errors.New("CryptoSetupServer: no forwardSecureAEAD")
}
return h.forwardSecureAEAD, nil
}
return nil, errors.New("CryptoSetupServer: no encryption level specified")
}
func (h *cryptoSetupServer) isInchoateCHLO(cryptoData map[Tag][]byte, cert []byte) bool {
if _, ok := cryptoData[TagPUBS]; !ok {
return true
}
scid, ok := cryptoData[TagSCID]
if !ok || !bytes.Equal(h.scfg.ID, scid) {
return true
}
xlctTag, ok := cryptoData[TagXLCT]
if !ok || len(xlctTag) != 8 {
return true
}
xlct := binary.LittleEndian.Uint64(xlctTag)
if crypto.HashCert(cert) != xlct {
return true
}
return !h.acceptSTK(cryptoData[TagSTK])
}
func (h *cryptoSetupServer) acceptSTK(token []byte) bool {
stk, err := h.scfg.cookieGenerator.DecodeToken(token)
if err != nil {
h.logger.Debugf("STK invalid: %s", err.Error())
return false
}
return h.acceptSTKCallback(h.remoteAddr, stk)
}
func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoData map[Tag][]byte) ([]byte, error) {
token, err := h.scfg.cookieGenerator.NewToken(h.remoteAddr)
if err != nil {
return nil, err
}
replyMap := map[Tag][]byte{
TagSCFG: h.scfg.Get(),
TagSTK: token,
TagSVID: []byte("quic-go"),
}
if h.acceptSTK(cryptoData[TagSTK]) {
proof, err := h.scfg.Sign(sni, chlo)
if err != nil {
return nil, err
}
commonSetHashes := cryptoData[TagCCS]
cachedCertsHashes := cryptoData[TagCCRT]
certCompressed, err := h.scfg.GetCertsCompressed(sni, commonSetHashes, cachedCertsHashes)
if err != nil {
return nil, err
}
// Token was valid, send more details
replyMap[TagPROF] = proof
replyMap[TagCERT] = certCompressed
}
message := HandshakeMessage{
Tag: TagREJ,
Data: replyMap,
}
var serverReply bytes.Buffer
message.Write(&serverReply)
h.logger.Debugf("Sending %s", message)
return serverReply.Bytes(), nil
}
func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[Tag][]byte) ([]byte, error) {
// We have a CHLO matching our server config, we can continue with the 0-RTT handshake
sharedSecret, err := h.scfg.kex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.mutex.Lock()
defer h.mutex.Unlock()
certUncompressed, err := h.scfg.certChain.GetLeafCert(sni)
if err != nil {
return nil, err
}
serverNonce := make([]byte, 32)
if _, err = rand.Read(serverNonce); err != nil {
return nil, err
}
clientNonce := cryptoData[TagNONC]
err = h.validateClientNonce(clientNonce)
if err != nil {
return nil, err
}
aead := cryptoData[TagAEAD]
if !bytes.Equal(aead, []byte("AESG")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
kexs := cryptoData[TagKEXS]
if !bytes.Equal(kexs, []byte("C255")) {
return nil, qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS")
}
h.secureAEAD, err = h.keyDerivation(
false,
sharedSecret,
clientNonce,
h.connID,
data,
h.scfg.Get(),
certUncompressed,
h.diversificationNonce,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for secure encryption.")
h.handshakeEvent <- struct{}{}
// Generate a new curve instance to derive the forward secure key
var fsNonce bytes.Buffer
fsNonce.Write(clientNonce)
fsNonce.Write(serverNonce)
ephermalKex, err := h.keyExchange()
if err != nil {
return nil, err
}
ephermalSharedSecret, err := ephermalKex.CalculateSharedKey(cryptoData[TagPUBS])
if err != nil {
return nil, err
}
h.forwardSecureAEAD, err = h.keyDerivation(
true,
ephermalSharedSecret,
fsNonce.Bytes(),
h.connID,
data,
h.scfg.Get(),
certUncompressed,
nil,
protocol.PerspectiveServer,
)
if err != nil {
return nil, err
}
h.logger.Debugf("Creating AEAD for forward-secure encryption.")
replyMap := h.params.getHelloMap()
// add crypto parameters
verTag := &bytes.Buffer{}
for _, v := range h.supportedVersions {
utils.BigEndian.WriteUint32(verTag, uint32(v))
}
replyMap[TagPUBS] = ephermalKex.PublicKey()
replyMap[TagSNO] = serverNonce
replyMap[TagVER] = verTag.Bytes()
// note that the SHLO *has* to fit into one packet
message := HandshakeMessage{
Tag: TagSHLO,
Data: replyMap,
}
var reply bytes.Buffer
message.Write(&reply)
h.logger.Debugf("Sending %s", message)
return reply.Bytes(), nil
}
func (h *cryptoSetupServer) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
return ConnectionState{
ServerName: h.sni,
HandshakeComplete: h.receivedForwardSecurePacket,
}
}
func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error {
if len(nonce) != 32 {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length")
}
if !bytes.Equal(nonce[4:12], h.scfg.obit) {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching")
}
return nil
}

View File

@@ -0,0 +1,163 @@
package handshake
import (
"errors"
"fmt"
"io"
"sync"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// KeyDerivationFunction is used for key derivation
type KeyDerivationFunction func(crypto.TLSExporter, protocol.Perspective) (crypto.AEAD, error)
type cryptoSetupTLS struct {
mutex sync.RWMutex
perspective protocol.Perspective
keyDerivation KeyDerivationFunction
nullAEAD crypto.AEAD
aead crypto.AEAD
tls mintTLS
conn *cryptoStreamConn
handshakeEvent chan<- struct{}
}
var _ CryptoSetupTLS = &cryptoSetupTLS{}
// NewCryptoSetupTLSServer creates a new TLS CryptoSetup instance for a server
func NewCryptoSetupTLSServer(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveServer, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Server(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
nullAEAD: nullAEAD,
perspective: protocol.PerspectiveServer,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
// NewCryptoSetupTLSClient creates a new TLS CryptoSetup instance for a client
func NewCryptoSetupTLSClient(
cryptoStream io.ReadWriter,
connID protocol.ConnectionID,
config *mint.Config,
handshakeEvent chan<- struct{},
version protocol.VersionNumber,
) (CryptoSetupTLS, error) {
nullAEAD, err := crypto.NewNullAEAD(protocol.PerspectiveClient, connID, version)
if err != nil {
return nil, err
}
conn := newCryptoStreamConn(cryptoStream)
tls := mint.Client(conn, config)
return &cryptoSetupTLS{
tls: tls,
conn: conn,
perspective: protocol.PerspectiveClient,
nullAEAD: nullAEAD,
keyDerivation: crypto.DeriveAESKeys,
handshakeEvent: handshakeEvent,
}, nil
}
func (h *cryptoSetupTLS) HandleCryptoStream() error {
for {
if alert := h.tls.Handshake(); alert != mint.AlertNoAlert {
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
}
state := h.tls.ConnectionState().HandshakeState
if err := h.conn.Flush(); err != nil {
return err
}
if state == mint.StateClientConnected || state == mint.StateServerConnected {
break
}
}
aead, err := h.keyDerivation(h.tls, h.perspective)
if err != nil {
return err
}
h.mutex.Lock()
h.aead = aead
h.mutex.Unlock()
h.handshakeEvent <- struct{}{}
close(h.handshakeEvent)
return nil
}
func (h *cryptoSetupTLS) OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
return h.nullAEAD.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead == nil {
return nil, errors.New("no 1-RTT sealer")
}
return h.aead.Open(dst, src, packetNumber, associatedData)
}
func (h *cryptoSetupTLS) GetSealer() (protocol.EncryptionLevel, Sealer) {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.aead != nil {
return protocol.EncryptionForwardSecure, h.aead
}
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) GetSealerWithEncryptionLevel(encLevel protocol.EncryptionLevel) (Sealer, error) {
errNoSealer := fmt.Errorf("CryptoSetup: no sealer with encryption level %s", encLevel.String())
h.mutex.RLock()
defer h.mutex.RUnlock()
switch encLevel {
case protocol.EncryptionUnencrypted:
return h.nullAEAD, nil
case protocol.EncryptionForwardSecure:
if h.aead == nil {
return nil, errNoSealer
}
return h.aead, nil
default:
return nil, errNoSealer
}
}
func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer) {
return protocol.EncryptionUnencrypted, h.nullAEAD
}
func (h *cryptoSetupTLS) ConnectionState() ConnectionState {
h.mutex.Lock()
defer h.mutex.Unlock()
mintConnState := h.tls.ConnectionState()
return ConnectionState{
// TODO: set the ServerName, once mint exports it
HandshakeComplete: h.aead != nil,
PeerCertificates: mintConnState.PeerCertificates,
}
}

View File

@@ -0,0 +1,69 @@
package handshake
import (
"bytes"
"io"
"net"
"time"
)
type cryptoStreamConn struct {
buffer *bytes.Buffer
stream io.ReadWriter
}
var _ net.Conn = &cryptoStreamConn{}
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
return &cryptoStreamConn{
stream: stream,
buffer: &bytes.Buffer{},
}
}
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
return c.stream.Read(b)
}
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
return c.buffer.Write(p)
}
func (c *cryptoStreamConn) Flush() error {
if c.buffer.Len() == 0 {
return nil
}
_, err := c.stream.Write(c.buffer.Bytes())
c.buffer.Reset()
return err
}
// Close is not implemented
func (c *cryptoStreamConn) Close() error {
return nil
}
// LocalAddr is not implemented
func (c *cryptoStreamConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr is not implemented
func (c *cryptoStreamConn) RemoteAddr() net.Addr {
return nil
}
// SetReadDeadline is not implemented
func (c *cryptoStreamConn) SetReadDeadline(time.Time) error {
return nil
}
// SetWriteDeadline is not implemented
func (c *cryptoStreamConn) SetWriteDeadline(time.Time) error {
return nil
}
// SetDeadline is not implemented
func (c *cryptoStreamConn) SetDeadline(time.Time) error {
return nil
}

View File

@@ -0,0 +1,48 @@
package handshake
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
kexLifetime = protocol.EphermalKeyLifetime
kexCurrent crypto.KeyExchange
kexCurrentTime time.Time
kexMutex sync.RWMutex
)
// getEphermalKEX returns the currently active KEX, which changes every protocol.EphermalKeyLifetime
// See the explanation from the QUIC crypto doc:
//
// A single connection is the usual scope for forward security, but the security
// difference between an ephemeral key used for a single connection, and one
// used for all connections for 60 seconds is negligible. Thus we can amortise
// the Diffie-Hellman key generation at the server over all the connections in a
// small time span.
func getEphermalKEX() (crypto.KeyExchange, error) {
kexMutex.RLock()
res := kexCurrent
t := kexCurrentTime
kexMutex.RUnlock()
if res != nil && time.Since(t) < kexLifetime {
return res, nil
}
kexMutex.Lock()
defer kexMutex.Unlock()
// Check if still unfulfilled
if kexCurrent == nil || time.Since(kexCurrentTime) >= kexLifetime {
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
kexCurrent = kex
kexCurrentTime = time.Now()
return kexCurrent, nil
}
return kexCurrent, nil
}

View File

@@ -0,0 +1,137 @@
package handshake
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"sort"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// A HandshakeMessage is a handshake message
type HandshakeMessage struct {
Tag Tag
Data map[Tag][]byte
}
var _ fmt.Stringer = &HandshakeMessage{}
// ParseHandshakeMessage reads a crypto message
func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) {
slice4 := make([]byte, 4)
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
messageTag := Tag(binary.LittleEndian.Uint32(slice4))
if _, err := io.ReadFull(r, slice4); err != nil {
return HandshakeMessage{}, err
}
nPairs := binary.LittleEndian.Uint32(slice4)
if nPairs > protocol.CryptoMaxParams {
return HandshakeMessage{}, qerr.CryptoTooManyEntries
}
index := make([]byte, nPairs*8)
if _, err := io.ReadFull(r, index); err != nil {
return HandshakeMessage{}, err
}
resultMap := map[Tag][]byte{}
var dataStart uint32
for indexPos := 0; indexPos < int(nPairs)*8; indexPos += 8 {
tag := Tag(binary.LittleEndian.Uint32(index[indexPos : indexPos+4]))
dataEnd := binary.LittleEndian.Uint32(index[indexPos+4 : indexPos+8])
dataLen := dataEnd - dataStart
if dataLen > protocol.CryptoParameterMaxLength {
return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long")
}
data := make([]byte, dataLen)
if _, err := io.ReadFull(r, data); err != nil {
return HandshakeMessage{}, err
}
resultMap[tag] = data
dataStart = dataEnd
}
return HandshakeMessage{
Tag: messageTag,
Data: resultMap}, nil
}
// Write writes a crypto message
func (h HandshakeMessage) Write(b *bytes.Buffer) {
data := h.Data
utils.LittleEndian.WriteUint32(b, uint32(h.Tag))
utils.LittleEndian.WriteUint16(b, uint16(len(data)))
utils.LittleEndian.WriteUint16(b, 0)
// Save current position in the buffer, so that we can update the index in-place later
indexStart := b.Len()
indexData := make([]byte, 8*len(data))
b.Write(indexData) // Will be updated later
offset := uint32(0)
for i, t := range h.getTagsSorted() {
v := data[t]
b.Write(v)
offset += uint32(len(v))
binary.LittleEndian.PutUint32(indexData[i*8:], uint32(t))
binary.LittleEndian.PutUint32(indexData[i*8+4:], offset)
}
// Now we write the index data for real
copy(b.Bytes()[indexStart:], indexData)
}
func (h *HandshakeMessage) getTagsSorted() []Tag {
tags := make([]Tag, len(h.Data))
i := 0
for t := range h.Data {
tags[i] = t
i++
}
sort.Slice(tags, func(i, j int) bool {
return tags[i] < tags[j]
})
return tags
}
func (h HandshakeMessage) String() string {
var pad string
res := tagToString(h.Tag) + ":\n"
for _, tag := range h.getTagsSorted() {
if tag == TagPAD {
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(tag), len(h.Data[tag]))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(tag), string(h.Data[tag]))
}
}
if len(pad) > 0 {
res += pad
}
return res
}
func tagToString(tag Tag) string {
b := make([]byte, 4)
binary.LittleEndian.PutUint32(b, uint32(tag))
for i := range b {
if b[i] == 0 {
b[i] = ' '
}
}
return string(b)
}

View File

@@ -0,0 +1,61 @@
package handshake
import (
"crypto/x509"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Sealer seals a packet
type Sealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
Overhead() int
}
// mintTLS combines some methods needed to interact with mint.
type mintTLS interface {
crypto.TLSExporter
Handshake() mint.Alert
}
// A TLSExtensionHandler sends and received the QUIC TLS extension.
// It provides the parameters sent by the peer on a channel.
type TLSExtensionHandler interface {
Send(mint.HandshakeType, *mint.ExtensionList) error
Receive(mint.HandshakeType, *mint.ExtensionList) error
GetPeerParams() <-chan TransportParameters
}
type baseCryptoSetup interface {
HandleCryptoStream() error
ConnectionState() ConnectionState
GetSealer() (protocol.EncryptionLevel, Sealer)
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
GetSealerForCryptoStream() (protocol.EncryptionLevel, Sealer)
}
// CryptoSetup is the crypto setup used by gQUIC
type CryptoSetup interface {
baseCryptoSetup
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error)
}
// CryptoSetupTLS is the crypto setup used by IETF QUIC
type CryptoSetupTLS interface {
baseCryptoSetup
OpenHandshake(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Open1RTT(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ConnectionState records basic details about the QUIC connection.
// Warning: This API should not be considered stable and might change soon.
type ConnectionState struct {
HandshakeComplete bool // handshake is complete
ServerName string // server name requested by client, if any (server side only)
PeerCertificates []*x509.Certificate // certificate chain presented by remote peer
}

View File

@@ -0,0 +1,3 @@
package handshake
//go:generate sh -c "../mockgen_internal.sh handshake mock_mint_tls_test.go github.com/lucas-clemente/quic-go/internal/handshake mintTLS"

View File

@@ -0,0 +1,73 @@
package handshake
import (
"bytes"
"crypto/rand"
"github.com/lucas-clemente/quic-go/internal/crypto"
)
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
certChain crypto.CertChain
ID []byte
obit []byte
cookieGenerator *CookieGenerator
}
// NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16)
_, err := rand.Read(id)
if err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
cookieGenerator, err := NewCookieGenerator()
if err != nil {
return nil, err
}
return &ServerConfig{
kex: kex,
certChain: certChain,
ID: id,
obit: obit,
cookieGenerator: cookieGenerator,
}, nil
}
// Get the server config binary representation
func (s *ServerConfig) Get() []byte {
var serverConfig bytes.Buffer
msg := HandshakeMessage{
Tag: TagSCFG,
Data: map[Tag][]byte{
TagSCID: s.ID,
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
},
}
msg.Write(&serverConfig)
return serverConfig.Bytes()
}
// Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.certChain.SignServerProof(sni, chlo, s.Get())
}
// GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
}

View File

@@ -0,0 +1,184 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/crypto"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
message, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if message.Tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(message.Data)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
c255Foundat := -1
for i := 0; i < len(kexs)/4; i++ {
if bytes.Equal(kexs[4*i:4*i+4], []byte("C255")) {
c255Foundat = i
break
}
}
if c255Foundat < 0 {
return qerr.Error(qerr.CryptoNoSupport, "KEXS: Could not find C255, other key exchanges are not supported")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
var pubsKexs []struct {
Length uint32
Value []byte
}
var lastLen uint32
for i := 0; i < len(pubs)-3; i += int(lastLen) + 3 {
// the PUBS value is always prepended by 3 byte little endian length field
err := binary.Read(bytes.NewReader([]byte{pubs[i], pubs[i+1], pubs[i+2], 0x00}), binary.LittleEndian, &lastLen)
if err != nil {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS not decodable")
}
if lastLen == 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
if i+3+int(lastLen) > len(pubs) {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
pubsKexs = append(pubsKexs, struct {
Length uint32
Value []byte
}{lastLen, pubs[i+3 : i+3+int(lastLen)]})
}
if c255Foundat >= len(pubsKexs) {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS not in PUBS")
}
if pubsKexs[c255Foundat].Length != 32 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubsKexs[c255Foundat].Value)
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}

View File

@@ -0,0 +1,93 @@
package handshake
// A Tag in the QUIC crypto
type Tag uint32
const (
// TagCHLO is a client hello
TagCHLO Tag = 'C' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagREJ is a server hello rejection
TagREJ Tag = 'R' + 'E'<<8 + 'J'<<16
// TagSCFG is a server config
TagSCFG Tag = 'S' + 'C'<<8 + 'F'<<16 + 'G'<<24
// TagPAD is padding
TagPAD Tag = 'P' + 'A'<<8 + 'D'<<16
// TagSNI is the server name indication
TagSNI Tag = 'S' + 'N'<<8 + 'I'<<16
// TagVER is the QUIC version
TagVER Tag = 'V' + 'E'<<8 + 'R'<<16
// TagCCS are the hashes of the common certificate sets
TagCCS Tag = 'C' + 'C'<<8 + 'S'<<16
// TagCCRT are the hashes of the cached certificates
TagCCRT Tag = 'C' + 'C'<<8 + 'R'<<16 + 'T'<<24
// TagMSPC is max streams per connection
TagMSPC Tag = 'M' + 'S'<<8 + 'P'<<16 + 'C'<<24
// TagMIDS is max incoming dyanamic streams
TagMIDS Tag = 'M' + 'I'<<8 + 'D'<<16 + 'S'<<24
// TagUAID is the user agent ID
TagUAID Tag = 'U' + 'A'<<8 + 'I'<<16 + 'D'<<24
// TagSVID is the server ID (unofficial tag by us :)
TagSVID Tag = 'S' + 'V'<<8 + 'I'<<16 + 'D'<<24
// TagTCID is truncation of the connection ID
TagTCID Tag = 'T' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagPDMD is the proof demand
TagPDMD Tag = 'P' + 'D'<<8 + 'M'<<16 + 'D'<<24
// TagSRBF is the socket receive buffer
TagSRBF Tag = 'S' + 'R'<<8 + 'B'<<16 + 'F'<<24
// TagICSL is the idle connection state lifetime
TagICSL Tag = 'I' + 'C'<<8 + 'S'<<16 + 'L'<<24
// TagNONP is the client proof nonce
TagNONP Tag = 'N' + 'O'<<8 + 'N'<<16 + 'P'<<24
// TagSCLS is the silently close timeout
TagSCLS Tag = 'S' + 'C'<<8 + 'L'<<16 + 'S'<<24
// TagCSCT is the signed cert timestamp (RFC6962) of leaf cert
TagCSCT Tag = 'C' + 'S'<<8 + 'C'<<16 + 'T'<<24
// TagCOPT are the connection options
TagCOPT Tag = 'C' + 'O'<<8 + 'P'<<16 + 'T'<<24
// TagCFCW is the initial session/connection flow control receive window
TagCFCW Tag = 'C' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagSFCW is the initial stream flow control receive window.
TagSFCW Tag = 'S' + 'F'<<8 + 'C'<<16 + 'W'<<24
// TagNSTP is the no STOP_WAITING experiment
// currently unsupported by quic-go
TagNSTP Tag = 'N' + 'S'<<8 + 'T'<<16 + 'P'<<24
// TagSTK is the source-address token
TagSTK Tag = 'S' + 'T'<<8 + 'K'<<16
// TagSNO is the server nonce
TagSNO Tag = 'S' + 'N'<<8 + 'O'<<16
// TagPROF is the server proof
TagPROF Tag = 'P' + 'R'<<8 + 'O'<<16 + 'F'<<24
// TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
// TagKEXS is the list of key exchange algos
TagKEXS Tag = 'K' + 'E'<<8 + 'X'<<16 + 'S'<<24
// TagAEAD is the list of AEAD algos
TagAEAD Tag = 'A' + 'E'<<8 + 'A'<<16 + 'D'<<24
// TagPUBS is the public value for the KEX
TagPUBS Tag = 'P' + 'U'<<8 + 'B'<<16 + 'S'<<24
// TagOBIT is the client orbit
TagOBIT Tag = 'O' + 'B'<<8 + 'I'<<16 + 'T'<<24
// TagEXPY is the server config expiry
TagEXPY Tag = 'E' + 'X'<<8 + 'P'<<16 + 'Y'<<24
// TagCERT is the CERT data
TagCERT Tag = 0xff545243
// TagSHLO is the server hello
TagSHLO Tag = 'S' + 'H'<<8 + 'L'<<16 + 'O'<<24
// TagPRST is the public reset tag
TagPRST Tag = 'P' + 'R'<<8 + 'S'<<16 + 'T'<<24
// TagRSEQ is the public reset rejected packet number
TagRSEQ Tag = 'R' + 'S'<<8 + 'E'<<16 + 'Q'<<24
// TagRNON is the public reset nonce
TagRNON Tag = 'R' + 'N'<<8 + 'O'<<16 + 'N'<<24
)

View File

@@ -0,0 +1,123 @@
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type transportParameterID uint16
const quicTLSExtensionType = 0xff5
const (
initialMaxStreamDataParameterID transportParameterID = 0x0
initialMaxDataParameterID transportParameterID = 0x1
initialMaxBidiStreamsParameterID transportParameterID = 0x2
idleTimeoutParameterID transportParameterID = 0x3
maxPacketSizeParameterID transportParameterID = 0x5
statelessResetTokenParameterID transportParameterID = 0x6
initialMaxUniStreamsParameterID transportParameterID = 0x8
disableMigrationParameterID transportParameterID = 0x9
)
type clientHelloTransportParameters struct {
InitialVersion protocol.VersionNumber
Parameters TransportParameters
}
func (p *clientHelloTransportParameters) Marshal() []byte {
const lenOffset = 4
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.InitialVersion))
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *clientHelloTransportParameters) Unmarshal(data []byte) error {
if len(data) < 6 {
return errors.New("transport parameter data too short")
}
p.InitialVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
paramsLen := int(binary.BigEndian.Uint16(data[4:6]))
data = data[6:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type encryptedExtensionsTransportParameters struct {
NegotiatedVersion protocol.VersionNumber
SupportedVersions []protocol.VersionNumber
Parameters TransportParameters
}
func (p *encryptedExtensionsTransportParameters) Marshal() []byte {
b := &bytes.Buffer{}
utils.BigEndian.WriteUint32(b, uint32(p.NegotiatedVersion))
b.WriteByte(uint8(4 * len(p.SupportedVersions)))
for _, v := range p.SupportedVersions {
utils.BigEndian.WriteUint32(b, uint32(v))
}
lenOffset := b.Len()
b.Write([]byte{0, 0}) // length. Will be replaced later
p.Parameters.marshal(b)
data := b.Bytes()
binary.BigEndian.PutUint16(data[lenOffset:lenOffset+2], uint16(len(data)-lenOffset-2))
return data
}
func (p *encryptedExtensionsTransportParameters) Unmarshal(data []byte) error {
if len(data) < 5 {
return errors.New("transport parameter data too short")
}
p.NegotiatedVersion = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
numVersions := int(data[4])
if numVersions%4 != 0 {
return fmt.Errorf("invalid length for version list: %d", numVersions)
}
numVersions /= 4
data = data[5:]
if len(data) < 4*numVersions+2 /*length field for the parameter list */ {
return errors.New("transport parameter data too short")
}
p.SupportedVersions = make([]protocol.VersionNumber, numVersions)
for i := 0; i < numVersions; i++ {
p.SupportedVersions[i] = protocol.VersionNumber(binary.BigEndian.Uint32(data[:4]))
data = data[4:]
}
paramsLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) != paramsLen {
return fmt.Errorf("expected transport parameters to be %d bytes long, have %d", paramsLen, len(data))
}
return p.Parameters.unmarshal(data)
}
type tlsExtensionBody struct {
data []byte
}
var _ mint.ExtensionBody = &tlsExtensionBody{}
func (e *tlsExtensionBody) Type() mint.ExtensionType {
return quicTLSExtensionType
}
func (e *tlsExtensionBody) Marshal() ([]byte, error) {
return e.data, nil
}
func (e *tlsExtensionBody) Unmarshal(data []byte) (int, error) {
e.data = data
return len(data), nil
}

View File

@@ -0,0 +1,112 @@
package handshake
import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type extensionHandlerClient struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
initialVersion protocol.VersionNumber
supportedVersions []protocol.VersionNumber
version protocol.VersionNumber
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerClient{}
var _ TLSExtensionHandler = &extensionHandlerClient{}
// NewExtensionHandlerClient creates a new extension handler for the client.
func NewExtensionHandlerClient(
params *TransportParameters,
initialVersion protocol.VersionNumber,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
// The client reads the transport parameters from the Encrypted Extensions message.
// The paramsChan is used in the session's run loop's select statement.
// We have to use an unbuffered channel here to make sure that the session actually processes the transport parameters immediately.
paramsChan := make(chan TransportParameters)
return &extensionHandlerClient{
ourParams: params,
paramsChan: paramsChan,
initialVersion: initialVersion,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}
func (h *extensionHandlerClient) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeClientHello {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
chtp := &clientHelloTransportParameters{
InitialVersion: h.initialVersion,
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: chtp.Marshal()})
}
func (h *extensionHandlerClient) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeEncryptedExtensions {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
return nil
}
// hType == mint.HandshakeTypeEncryptedExtensions
if !found {
return errors.New("EncryptedExtensions message didn't contain a QUIC extension")
}
eetp := &encryptedExtensionsTransportParameters{}
if err := eetp.Unmarshal(ext.data); err != nil {
return err
}
// check that the negotiated_version is the current version
if eetp.NegotiatedVersion != h.version {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version doesn't match negotiated_version")
}
// check that the current version is included in the supported versions
if !protocol.IsSupportedVersion(eetp.SupportedVersions, h.version) {
return qerr.Error(qerr.VersionNegotiationMismatch, "current version not included in the supported versions")
}
// if version negotiation was performed, check that we would have selected the current version based on the supported versions sent by the server
if h.version != h.initialVersion {
negotiatedVersion, ok := protocol.ChooseSupportedVersion(h.supportedVersions, eetp.SupportedVersions)
if !ok || h.version != negotiatedVersion {
return qerr.Error(qerr.VersionNegotiationMismatch, "would have picked a different version")
}
}
// check that the server sent a stateless reset token
if len(eetp.Parameters.StatelessResetToken) == 0 {
return errors.New("server didn't sent stateless_reset_token")
}
h.logger.Debugf("Received Transport Parameters: %s", &eetp.Parameters)
h.paramsChan <- eetp.Parameters
return nil
}
func (h *extensionHandlerClient) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@@ -0,0 +1,100 @@
package handshake
import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/bifurcation/mint"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type extensionHandlerServer struct {
ourParams *TransportParameters
paramsChan chan TransportParameters
version protocol.VersionNumber
supportedVersions []protocol.VersionNumber
logger utils.Logger
}
var _ mint.AppExtensionHandler = &extensionHandlerServer{}
var _ TLSExtensionHandler = &extensionHandlerServer{}
// NewExtensionHandlerServer creates a new extension handler for the server
func NewExtensionHandlerServer(
params *TransportParameters,
supportedVersions []protocol.VersionNumber,
version protocol.VersionNumber,
logger utils.Logger,
) TLSExtensionHandler {
// Processing the ClientHello is performed statelessly (and from a single go-routine).
// Therefore, we have to use a buffered chan to pass the transport parameters to that go routine.
paramsChan := make(chan TransportParameters, 1)
return &extensionHandlerServer{
ourParams: params,
paramsChan: paramsChan,
supportedVersions: supportedVersions,
version: version,
logger: logger,
}
}
func (h *extensionHandlerServer) Send(hType mint.HandshakeType, el *mint.ExtensionList) error {
if hType != mint.HandshakeTypeEncryptedExtensions {
return nil
}
h.logger.Debugf("Sending Transport Parameters: %s", h.ourParams)
eetp := &encryptedExtensionsTransportParameters{
NegotiatedVersion: h.version,
SupportedVersions: protocol.GetGreasedVersions(h.supportedVersions),
Parameters: *h.ourParams,
}
return el.Add(&tlsExtensionBody{data: eetp.Marshal()})
}
func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.ExtensionList) error {
ext := &tlsExtensionBody{}
found, err := el.Find(ext)
if err != nil {
return err
}
if hType != mint.HandshakeTypeClientHello {
if found {
return fmt.Errorf("Unexpected QUIC extension in handshake message %d", hType)
}
return nil
}
if !found {
return errors.New("ClientHello didn't contain a QUIC extension")
}
chtp := &clientHelloTransportParameters{}
if err := chtp.Unmarshal(ext.data); err != nil {
return err
}
// perform the stateless version negotiation validation:
// make sure that we would have sent a Version Negotiation Packet if the client offered the initial version
// this is the case if and only if the initial version is not contained in the supported versions
if chtp.InitialVersion != h.version && protocol.IsSupportedVersion(h.supportedVersions, chtp.InitialVersion) {
return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version")
}
// check that the client didn't send a stateless reset token
if len(chtp.Parameters.StatelessResetToken) != 0 {
// TODO: return the correct error type
return errors.New("client sent a stateless reset token")
}
h.logger.Debugf("Received Transport Parameters: %s", &chtp.Parameters)
h.paramsChan <- chtp.Parameters
return nil
}
func (h *extensionHandlerServer) GetPeerParams() <-chan TransportParameters {
return h.paramsChan
}

View File

@@ -0,0 +1,215 @@
package handshake
import (
"bytes"
"encoding/binary"
"fmt"
"sort"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/qerr"
)
// errMalformedTag is returned when the tag value cannot be read
var errMalformedTag = qerr.Error(qerr.InvalidCryptoMessageParameter, "malformed Tag value")
// TransportParameters are parameters sent to the peer during the handshake
type TransportParameters struct {
StreamFlowControlWindow protocol.ByteCount
ConnectionFlowControlWindow protocol.ByteCount
MaxPacketSize protocol.ByteCount
MaxUniStreams uint16 // only used for IETF QUIC
MaxBidiStreams uint16 // only used for IETF QUIC
MaxStreams uint32 // only used for gQUIC
OmitConnectionID bool // only used for gQUIC
IdleTimeout time.Duration
DisableMigration bool // only used for IETF QUIC
StatelessResetToken []byte // only used for IETF QUIC
}
// readHelloMap reads the transport parameters from the tags sent in a gQUIC handshake message
func readHelloMap(tags map[Tag][]byte) (*TransportParameters, error) {
params := &TransportParameters{}
if value, ok := tags[TagTCID]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.OmitConnectionID = (v == 0)
}
if value, ok := tags[TagMIDS]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.MaxStreams = v
}
if value, ok := tags[TagICSL]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(v)*time.Second)
}
if value, ok := tags[TagSFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.StreamFlowControlWindow = protocol.ByteCount(v)
}
if value, ok := tags[TagCFCW]; ok {
v, err := utils.LittleEndian.ReadUint32(bytes.NewBuffer(value))
if err != nil {
return nil, errMalformedTag
}
params.ConnectionFlowControlWindow = protocol.ByteCount(v)
}
return params, nil
}
// GetHelloMap gets all parameters needed for the Hello message in the gQUIC handshake.
func (p *TransportParameters) getHelloMap() map[Tag][]byte {
sfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(sfcw, uint32(p.StreamFlowControlWindow))
cfcw := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(cfcw, uint32(p.ConnectionFlowControlWindow))
mids := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(mids, p.MaxStreams)
icsl := bytes.NewBuffer([]byte{})
utils.LittleEndian.WriteUint32(icsl, uint32(p.IdleTimeout/time.Second))
tags := map[Tag][]byte{
TagICSL: icsl.Bytes(),
TagMIDS: mids.Bytes(),
TagCFCW: cfcw.Bytes(),
TagSFCW: sfcw.Bytes(),
}
if p.OmitConnectionID {
tags[TagTCID] = []byte{0, 0, 0, 0}
}
return tags
}
func (p *TransportParameters) unmarshal(data []byte) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
for len(data) >= 4 {
paramID := transportParameterID(binary.BigEndian.Uint16(data[:2]))
paramLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(data), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
case initialMaxStreamDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_stream_data: %d (expected 4)", paramLen)
}
p.StreamFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxDataParameterID:
if paramLen != 4 {
return fmt.Errorf("wrong length for initial_max_data: %d (expected 4)", paramLen)
}
p.ConnectionFlowControlWindow = protocol.ByteCount(binary.BigEndian.Uint32(data[:4]))
case initialMaxBidiStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_bidi: %d (expected 2)", paramLen)
}
p.MaxBidiStreams = binary.BigEndian.Uint16(data[:2])
case initialMaxUniStreamsParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for initial_max_stream_id_uni: %d (expected 2)", paramLen)
}
p.MaxUniStreams = binary.BigEndian.Uint16(data[:2])
case idleTimeoutParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for idle_timeout: %d (expected 2)", paramLen)
}
p.IdleTimeout = utils.MaxDuration(protocol.MinRemoteIdleTimeout, time.Duration(binary.BigEndian.Uint16(data[:2]))*time.Second)
case maxPacketSizeParameterID:
if paramLen != 2 {
return fmt.Errorf("wrong length for max_packet_size: %d (expected 2)", paramLen)
}
maxPacketSize := protocol.ByteCount(binary.BigEndian.Uint16(data[:2]))
if maxPacketSize < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", maxPacketSize)
}
p.MaxPacketSize = maxPacketSize
case disableMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_migration: %d (expected empty)", paramLen)
}
p.DisableMigration = true
case statelessResetTokenParameterID:
if paramLen != 16 {
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
p.StatelessResetToken = data[:16]
}
data = data[paramLen:]
}
// check that every transport parameter was sent at most once
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
for i := 0; i < len(parameterIDs)-1; i++ {
if parameterIDs[i] == parameterIDs[i+1] {
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
}
}
if len(data) != 0 {
return fmt.Errorf("should have read all data. Still have %d bytes", len(data))
}
return nil
}
func (p *TransportParameters) marshal(b *bytes.Buffer) {
// initial_max_stream_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxStreamDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.StreamFlowControlWindow))
// initial_max_data
utils.BigEndian.WriteUint16(b, uint16(initialMaxDataParameterID))
utils.BigEndian.WriteUint16(b, 4)
utils.BigEndian.WriteUint32(b, uint32(p.ConnectionFlowControlWindow))
// initial_max_bidi_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxBidiStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxBidiStreams)
// initial_max_uni_streams
utils.BigEndian.WriteUint16(b, uint16(initialMaxUniStreamsParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, p.MaxUniStreams)
// idle_timeout
utils.BigEndian.WriteUint16(b, uint16(idleTimeoutParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(p.IdleTimeout/time.Second))
// max_packet_size
utils.BigEndian.WriteUint16(b, uint16(maxPacketSizeParameterID))
utils.BigEndian.WriteUint16(b, 2)
utils.BigEndian.WriteUint16(b, uint16(protocol.MaxReceivePacketSize))
// disable_migration
if p.DisableMigration {
utils.BigEndian.WriteUint16(b, uint16(disableMigrationParameterID))
utils.BigEndian.WriteUint16(b, 0)
}
if len(p.StatelessResetToken) > 0 {
utils.BigEndian.WriteUint16(b, uint16(statelessResetTokenParameterID))
utils.BigEndian.WriteUint16(b, uint16(len(p.StatelessResetToken))) // should always be 16 bytes
b.Write(p.StatelessResetToken)
}
}
// String returns a string representation, intended for logging.
// It should only used for IETF QUIC.
func (p *TransportParameters) String() string {
return fmt.Sprintf("&handshake.TransportParameters{StreamFlowControlWindow: %#x, ConnectionFlowControlWindow: %#x, MaxBidiStreams: %d, MaxUniStreams: %d, IdleTimeout: %s}", p.StreamFlowControlWindow, p.ConnectionFlowControlWindow, p.MaxBidiStreams, p.MaxUniStreams, p.IdleTimeout)
}

View File

@@ -0,0 +1,69 @@
package protocol
import (
"bytes"
"crypto/rand"
"fmt"
"io"
)
// A ConnectionID in QUIC
type ConnectionID []byte
const maxConnectionIDLen = 18
// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return ConnectionID(b), nil
}
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 18 bytes.
func GenerateConnectionIDForInitial() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
if len == 0 {
return nil, nil
}
c := make(ConnectionID, len)
_, err := io.ReadFull(r, c)
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return c, err
}
// Equal says if two connection IDs are equal
func (c ConnectionID) Equal(other ConnectionID) bool {
return bytes.Equal(c, other)
}
// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return len(c)
}
// Bytes returns the byte representation
func (c ConnectionID) Bytes() []byte {
return []byte(c)
}
func (c ConnectionID) String() string {
if c.Len() == 0 {
return "(empty)"
}
return fmt.Sprintf("%#x", c.Bytes())
}

View File

@@ -0,0 +1,28 @@
package protocol
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel int
const (
// EncryptionUnspecified is a not specified encryption level
EncryptionUnspecified EncryptionLevel = iota
// EncryptionUnencrypted is not encrypted
EncryptionUnencrypted
// EncryptionSecure is encrypted, but not forward secure
EncryptionSecure
// EncryptionForwardSecure is forward secure
EncryptionForwardSecure
)
func (e EncryptionLevel) String() string {
switch e {
case EncryptionUnencrypted:
return "unencrypted"
case EncryptionSecure:
return "encrypted (not forward-secure)"
case EncryptionForwardSecure:
return "forward-secure"
}
return "unknown"
}

View File

@@ -0,0 +1,70 @@
package protocol
// InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func InferPacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
version VersionNumber,
) PacketNumber {
var epochDelta PacketNumber
if version.UsesVarintPacketNumbers() {
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 7
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 14
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 30
}
} else {
epochDelta = PacketNumber(1) << (uint8(packetNumberLength) * 8)
}
epoch := lastPacketNumber & ^(epochDelta - 1)
prevEpochBegin := epoch - epochDelta
nextEpochBegin := epoch + epochDelta
return closestTo(
lastPacketNumber+1,
epoch+wirePacketNumber,
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
)
}
func closestTo(target, a, b PacketNumber) PacketNumber {
if delta(target, a) < delta(target, b) {
return a
}
return b
}
func delta(a, b PacketNumber) PacketNumber {
if a < b {
return b - a
}
return a - b
}
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if version.UsesVarintPacketNumbers() && diff < (1<<(14-1)) ||
!version.UsesVarintPacketNumbers() && diff < (1<<(16-1)) {
return PacketNumberLen2
}
return PacketNumberLen4
}
// GetPacketNumberLength gets the minimum length needed to fully represent the packet number
func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
if packetNumber < (1 << (uint8(PacketNumberLen1) * 8)) {
return PacketNumberLen1
}
if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
return PacketNumberLen2
}
if packetNumber < (1 << (uint8(PacketNumberLen4) * 8)) {
return PacketNumberLen4
}
return PacketNumberLen6
}

View File

@@ -0,0 +1,26 @@
package protocol
// Perspective determines if we're acting as a server or a client
type Perspective int
// the perspectives
const (
PerspectiveServer Perspective = 1
PerspectiveClient Perspective = 2
)
// Opposite returns the perspective of the peer
func (p Perspective) Opposite() Perspective {
return 3 - p
}
func (p Perspective) String() string {
switch p {
case PerspectiveServer:
return "Server"
case PerspectiveClient:
return "Client"
default:
return "invalid perspective"
}
}

View File

@@ -0,0 +1,90 @@
package protocol
import (
"fmt"
)
// A PacketNumber in QUIC
type PacketNumber uint64
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLenInvalid is the default value and not a valid length for a packet number
PacketNumberLenInvalid PacketNumberLen = 0
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
// PacketNumberLen6 is a packet number length of 6 bytes
PacketNumberLen6 PacketNumberLen = 6
)
// The PacketType is the Long Header Type (only used for the IETF draft header format)
type PacketType uint8
const (
// PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 0x7f
// PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry PacketType = 0x7e
// PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake PacketType = 0x7d
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT PacketType = 0x7c
)
func (t PacketType) String() string {
switch t {
case PacketTypeInitial:
return "Initial"
case PacketTypeRetry:
return "Retry"
case PacketTypeHandshake:
return "Handshake"
case PacketType0RTT:
return "0-RTT Protected"
default:
return fmt.Sprintf("unknown packet type: %d", t)
}
}
// A ByteCount in QUIC
type ByteCount uint64
// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint16
// MaxReceivePacketSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxReceivePacketSize ByteCount = 1452
// DefaultTCPMSS is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
const DefaultTCPMSS ByteCount = 1460
// MinClientHelloSize is the minimum size the server expects an inchoate CHLO to have (in gQUIC)
const MinClientHelloSize = 1024
// MinInitialPacketSize is the minimum size an Initial packet (in IETF QUIC) is required to have.
const MinInitialPacketSize = 1200
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
// ConnectionIDLenGQUIC is the length of the source Connection ID used on gQUIC QUIC packets.
const ConnectionIDLenGQUIC = 8
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8

View File

@@ -0,0 +1,151 @@
package protocol
import "time"
// MaxPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
const MaxPacketSizeIPv4 = 1252
// MaxPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const MaxPacketSizeIPv6 = 1232
// NonForwardSecurePacketSizeReduction is the number of bytes a non forward-secure packet has to be smaller than a forward-secure packet
// This makes sure that those packets can always be retransmitted without splitting the contained StreamFrames
const NonForwardSecurePacketSizeReduction = 50
const defaultMaxCongestionWindowPackets = 1000
// DefaultMaxCongestionWindow is the default for the max congestion window
const DefaultMaxCongestionWindow ByteCount = defaultMaxCongestionWindowPackets * DefaultTCPMSS
// InitialCongestionWindow is the initial congestion window in QUIC packets
const InitialCongestionWindow ByteCount = 32 * DefaultTCPMSS
// MaxUndecryptablePackets limits the number of undecryptable packets that a
// session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10
// PublicResetTimeout is the time to wait before sending a Public Reset when receiving too many undecryptable packets during the handshake
// This timeout allows the Go scheduler to switch to the Go rountine that reads the crypto stream and to escalate the crypto
const PublicResetTimeout = 500 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveStreamFlowControlWindow = (1 << 10) * 32 // 32 kB
// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow = (1 << 10) * 48 // 48 kB
// DefaultMaxReceiveStreamFlowControlWindowServer is the default maximum stream-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const DefaultMaxReceiveStreamFlowControlWindowServer = 1 * (1 << 20) // 1 MB
// DefaultMaxReceiveConnectionFlowControlWindowServer is the default connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const DefaultMaxReceiveConnectionFlowControlWindowServer = 1.5 * (1 << 20) // 1.5 MB
// DefaultMaxReceiveStreamFlowControlWindowClient is the default maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using
const DefaultMaxReceiveStreamFlowControlWindowClient = 6 * (1 << 20) // 6 MB
// DefaultMaxReceiveConnectionFlowControlWindowClient is the default connection-level flow control window for receiving data, for the client
// This is the value that Google servers are using
const DefaultMaxReceiveConnectionFlowControlWindowClient = 15 * (1 << 20) // 15 MB
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25
// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open
const DefaultMaxIncomingStreams = 100
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
const DefaultMaxIncomingUniStreams = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1
// MaxStreamsMinimumIncrement is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this absolute increment and the procentual increase specified by MaxStreamsMultiplier is used.
const MaxStreamsMinimumIncrement = 10
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = defaultMaxCongestionWindowPackets
// SkipPacketAveragePeriodLength is the average period length in which one packet number is skipped to prevent an Optimistic ACK attack
const SkipPacketAveragePeriodLength PacketNumber = 500
// MaxTrackedSkippedPackets is the maximum number of skipped packet numbers the SentPacketHandler keep track of for Optimistic ACK attack mitigation
const MaxTrackedSkippedPackets = 10
// CookieExpiryTime is the valid time of a cookie
const CookieExpiryTime = 24 * time.Hour
// MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
// When reached, it imposes a soft limit on sending new packets:
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
const MaxOutstandingSentPackets = 2 * defaultMaxCongestionWindowPackets
// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission.
// When reached, no more packets will be sent.
// This value *must* be larger than MaxOutstandingSentPackets.
const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = defaultMaxCongestionWindowPackets
// MaxNonRetransmittableAcks is the maximum number of packets containing an ACK, but no retransmittable frames, that we send in a row
const MaxNonRetransmittableAcks = 19
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
// Value taken from Chrome.
const CryptoMaxParams = 128
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
const CryptoParameterMaxLength = 4000
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
const EphermalKeyLifetime = time.Minute
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
const MinRemoteIdleTimeout = 5 * time.Second
// DefaultIdleTimeout is the default idle timeout
const DefaultIdleTimeout = 30 * time.Second
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const DefaultHandshakeTimeout = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted
const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
// 1. it reduces the framing overhead
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128
// MaxAckFrameSize is the maximum size for an (IETF QUIC) ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
// The MaxAckFrameSize should be large enough to encode many ACK range,
// but must ensure that a maximum size ACK frame fits into one packet.
const MaxAckFrameSize ByteCount = 1000
// MinPacingDelay is the minimum duration that is used for packet pacing
// If the packet packing frequency is higher, multiple packets might be sent at once.
// Example: For a packet pacing delay of 20 microseconds, we would send 5 packets at once, wait for 100 microseconds, and so forth.
const MinPacingDelay time.Duration = 100 * time.Microsecond
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
// if no other value is configured.
const DefaultConnectionIDLength = 4

View File

@@ -0,0 +1,36 @@
package protocol
// A StreamID in QUIC
type StreamID uint64
// MaxBidiStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams bidirectional streams.
// It is only valid for IETF QUIC.
func MaxBidiStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 1
} else {
first = 4
}
return first + 4*StreamID(numStreams-1)
}
// MaxUniStreamID is the highest stream ID that the peer is allowed to open,
// when it is allowed to open numStreams unidirectional streams.
// It is only valid for IETF QUIC.
func MaxUniStreamID(numStreams int, pers Perspective) StreamID {
if numStreams == 0 {
return 0
}
var first StreamID
if pers == PerspectiveClient {
first = 3
} else {
first = 2
}
return first + 4*StreamID(numStreams-1)
}

View File

@@ -0,0 +1,181 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
)
// VersionNumber is a version number as int
type VersionNumber uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const (
gquicVersion0 = 0x51303030
maxGquicVersion = 0x51303439
)
// The version numbers, making grepping easier
const (
Version39 VersionNumber = gquicVersion0 + 3*0x100 + 0x9
Version43 VersionNumber = gquicVersion0 + 4*0x100 + 0x3
Version44 VersionNumber = gquicVersion0 + 4*0x100 + 0x4
VersionTLS VersionNumber = 101
VersionWhatever VersionNumber = 0 // for when the version doesn't matter
VersionUnknown VersionNumber = math.MaxUint32
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{
Version44,
Version43,
Version39,
}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
// UsesTLS says if this QUIC version uses TLS 1.3 for the handshake
func (vn VersionNumber) UsesTLS() bool {
return !vn.isGQUIC()
}
func (vn VersionNumber) String() string {
switch vn {
case VersionWhatever:
return "whatever"
case VersionUnknown:
return "unknown"
case VersionTLS:
return "TLS dev version (WIP)"
default:
if vn.isGQUIC() {
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
}
return fmt.Sprintf("%#x", uint32(vn))
}
}
// ToAltSvc returns the representation of the version for the H2 Alt-Svc parameters
func (vn VersionNumber) ToAltSvc() string {
if vn.isGQUIC() {
return fmt.Sprintf("%d", vn.toGQUICVersion())
}
return fmt.Sprintf("%d", vn)
}
// CryptoStreamID gets the Stream ID of the crypto stream
func (vn VersionNumber) CryptoStreamID() StreamID {
if vn.isGQUIC() {
return 1
}
return 0
}
// UsesIETFFrameFormat tells if this version uses the IETF frame format
func (vn VersionNumber) UsesIETFFrameFormat() bool {
return !vn.isGQUIC()
}
// UsesIETFHeaderFormat tells if this version uses the IETF header format
func (vn VersionNumber) UsesIETFHeaderFormat() bool {
return !vn.isGQUIC() || vn >= Version44
}
// UsesLengthInHeader tells if this version uses the Length field in the IETF header
func (vn VersionNumber) UsesLengthInHeader() bool {
return !vn.isGQUIC()
}
// UsesTokenInHeader tells if this version uses the Token field in the IETF header
func (vn VersionNumber) UsesTokenInHeader() bool {
return !vn.isGQUIC()
}
// UsesStopWaitingFrames tells if this version uses STOP_WAITING frames
func (vn VersionNumber) UsesStopWaitingFrames() bool {
return vn.isGQUIC() && vn <= Version43
}
// UsesVarintPacketNumbers tells if this version uses 7/14/30 bit packet numbers
func (vn VersionNumber) UsesVarintPacketNumbers() bool {
return !vn.isGQUIC()
}
// StreamContributesToConnectionFlowControl says if a stream contributes to connection-level flow control
func (vn VersionNumber) StreamContributesToConnectionFlowControl(id StreamID) bool {
if id == vn.CryptoStreamID() {
return false
}
if vn.isGQUIC() && id == 3 {
return false
}
return true
}
func (vn VersionNumber) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}
func (vn VersionNumber) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}
// IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
for _, t := range supported {
if t == v {
return true
}
}
return false
}
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found.
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
for _, ourVer := range ours {
for _, theirVer := range theirs {
if ourVer == theirVer {
return ourVer, true
}
}
}
return 0, false
}
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() VersionNumber {
b := make([]byte, 4)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
}
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
b := make([]byte, 1)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
randPos := int(b[0]) % (len(supported) + 1)
greased := make([]VersionNumber, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()
copy(greased[randPos+1:], supported[randPos:])
return greased
}
// StripGreasedVersions strips all greased versions from a slice of versions
func StripGreasedVersions(versions []VersionNumber) []VersionNumber {
realVersions := make([]VersionNumber, 0, len(versions))
for _, v := range versions {
if v&0x0f0f0f0f != 0x0a0a0a0a {
realVersions = append(realVersions, v)
}
}
return realVersions
}

View File

@@ -0,0 +1,22 @@
package utils
import "sync/atomic"
// An AtomicBool is an atomic bool
type AtomicBool struct {
v int32
}
// Set sets the value
func (a *AtomicBool) Set(value bool) {
var n int32
if value {
n = 1
}
atomic.StoreInt32(&a.v, n)
}
// Get gets the value
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.v) != 0
}

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package utils
// Linked list implementation from the Go standard library.
// ByteIntervalElement is an element of a linked list.
type ByteIntervalElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *ByteIntervalElement
// The list to which this element belongs.
list *ByteIntervalList
// The value stored with this element.
Value ByteInterval
}
// Next returns the next list element or nil.
func (e *ByteIntervalElement) Next() *ByteIntervalElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *ByteIntervalElement) Prev() *ByteIntervalElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// ByteIntervalList is a linked list of ByteIntervals.
type ByteIntervalList struct {
root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *ByteIntervalList) Init() *ByteIntervalList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewByteIntervalList returns an initialized list.
func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *ByteIntervalList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *ByteIntervalList) Front() *ByteIntervalElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *ByteIntervalList) Back() *ByteIntervalElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *ByteIntervalList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement {
return l.insert(&ByteIntervalElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,25 @@
package utils
import (
"bytes"
"io"
)
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
ReadUintN(b io.ByteReader, length uint8) (uint64, error)
ReadUint64(io.ByteReader) (uint64, error)
ReadUint32(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUint64(*bytes.Buffer, uint64)
WriteUint56(*bytes.Buffer, uint64)
WriteUint48(*bytes.Buffer, uint64)
WriteUint40(*bytes.Buffer, uint64)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
ReadUfloat16(io.ByteReader) (uint64, error)
WriteUfloat16(*bytes.Buffer, uint64)
}

View File

@@ -0,0 +1,157 @@
package utils
import (
"bytes"
"fmt"
"io"
)
// BigEndian is the big-endian implementation of ByteOrder.
var BigEndian ByteOrder = bigEndian{}
type bigEndian struct{}
var _ ByteOrder = &bigEndian{}
// ReadUintN reads N bytes
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << ((length - 1 - i) * 8)
}
return res, nil
}
// ReadUint64 reads a uint64
func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint16 reads a uint16
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint56 writes 56 bit of a uint64
func (bigEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint48 writes 48 bit of a uint64
func (bigEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i >> 40), uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint40 writes 40 bit of a uint64
func (bigEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i >> 32),
uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
})
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes 24 bit of a uint32
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})
}
func (l bigEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l bigEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View File

@@ -0,0 +1,157 @@
package utils
import (
"bytes"
"fmt"
"io"
)
// LittleEndian is the little-endian implementation of ByteOrder.
var LittleEndian ByteOrder = littleEndian{}
type littleEndian struct{}
var _ ByteOrder = &littleEndian{}
// ReadUintN reads N bytes
func (littleEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << (i * 8)
}
return res, nil
}
// ReadUint64 reads a uint64
func (littleEndian) ReadUint64(b io.ByteReader) (uint64, error) {
var b1, b2, b3, b4, b5, b6, b7, b8 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b5, err = b.ReadByte(); err != nil {
return 0, err
}
if b6, err = b.ReadByte(); err != nil {
return 0, err
}
if b7, err = b.ReadByte(); err != nil {
return 0, err
}
if b8, err = b.ReadByte(); err != nil {
return 0, err
}
return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
}
// ReadUint32 reads a uint32
func (littleEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint16 reads a uint16
func (littleEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint64 writes a uint64
func (littleEndian) WriteUint64(b *bytes.Buffer, i uint64) {
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48), uint8(i >> 56),
})
}
// WriteUint56 writes 56 bit of a uint64
func (littleEndian) WriteUint56(b *bytes.Buffer, i uint64) {
if i >= (1 << 56) {
panic(fmt.Sprintf("%#x doesn't fit into 56 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40), uint8(i >> 48),
})
}
// WriteUint48 writes 48 bit of a uint64
func (littleEndian) WriteUint48(b *bytes.Buffer, i uint64) {
if i >= (1 << 48) {
panic(fmt.Sprintf("%#x doesn't fit into 48 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24),
uint8(i >> 32), uint8(i >> 40),
})
}
// WriteUint40 writes 40 bit of a uint64
func (littleEndian) WriteUint40(b *bytes.Buffer, i uint64) {
if i >= (1 << 40) {
panic(fmt.Sprintf("%#x doesn't fit into 40 bits", i))
}
b.Write([]byte{
uint8(i), uint8(i >> 8), uint8(i >> 16),
uint8(i >> 24), uint8(i >> 32),
})
}
// WriteUint32 writes a uint32
func (littleEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16), uint8(i >> 24)})
}
// WriteUint24 writes 24 bit of a uint32
func (littleEndian) WriteUint24(b *bytes.Buffer, i uint32) {
if i >= (1 << 24) {
panic(fmt.Sprintf("%#x doesn't fit into 24 bits", i))
}
b.Write([]byte{uint8(i), uint8(i >> 8), uint8(i >> 16)})
}
// WriteUint16 writes a uint16
func (littleEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i), uint8(i >> 8)})
}
func (l littleEndian) ReadUfloat16(b io.ByteReader) (uint64, error) {
return readUfloat16(b, l)
}
func (l littleEndian) WriteUfloat16(b *bytes.Buffer, val uint64) {
writeUfloat16(b, l, val)
}

View File

@@ -0,0 +1,86 @@
package utils
import (
"bytes"
"io"
"math"
)
// We define an unsigned 16-bit floating point value, inspired by IEEE floats
// (http://en.wikipedia.org/wiki/Half_precision_floating-point_format),
// with 5-bit exponent (bias 1), 11-bit mantissa (effective 12 with hidden
// bit) and denormals, but without signs, transfinites or fractions. Wire format
// 16 bits (little-endian byte order) are split into exponent (high 5) and
// mantissa (low 11) and decoded as:
// uint64_t value;
// if (exponent == 0) value = mantissa;
// else value = (mantissa | 1 << 11) << (exponent - 1)
const uFloat16ExponentBits = 5
const uFloat16MaxExponent = (1 << uFloat16ExponentBits) - 2 // 30
const uFloat16MantissaBits = 16 - uFloat16ExponentBits // 11
const uFloat16MantissaEffectiveBits = uFloat16MantissaBits + 1 // 12
const uFloat16MaxValue = ((uint64(1) << uFloat16MantissaEffectiveBits) - 1) << uFloat16MaxExponent // 0x3FFC0000000
// readUfloat16 reads a float in the QUIC-float16 format and returns its uint64 representation
func readUfloat16(b io.ByteReader, byteOrder ByteOrder) (uint64, error) {
val, err := byteOrder.ReadUint16(b)
if err != nil {
return 0, err
}
res := uint64(val)
if res < (1 << uFloat16MantissaEffectiveBits) {
// Fast path: either the value is denormalized (no hidden bit), or
// normalized (hidden bit set, exponent offset by one) with exponent zero.
// Zero exponent offset by one sets the bit exactly where the hidden bit is.
// So in both cases the value encodes itself.
return res, nil
}
exponent := val >> uFloat16MantissaBits // No sign extend on uint!
// After the fast pass, the exponent is at least one (offset by one).
// Un-offset the exponent.
exponent--
// Here we need to clear the exponent and set the hidden bit. We have already
// decremented the exponent, so when we subtract it, it leaves behind the
// hidden bit.
res -= uint64(exponent) << uFloat16MantissaBits
res <<= exponent
return res, nil
}
// writeUfloat16 writes a float in the QUIC-float16 format from its uint64 representation
func writeUfloat16(b *bytes.Buffer, byteOrder ByteOrder, value uint64) {
var result uint16
if value < (uint64(1) << uFloat16MantissaEffectiveBits) {
// Fast path: either the value is denormalized, or has exponent zero.
// Both cases are represented by the value itself.
result = uint16(value)
} else if value >= uFloat16MaxValue {
// Value is out of range; clamp it to the maximum representable.
result = math.MaxUint16
} else {
// The highest bit is between position 13 and 42 (zero-based), which
// corresponds to exponent 1-30. In the output, mantissa is from 0 to 10,
// hidden bit is 11 and exponent is 11 to 15. Shift the highest bit to 11
// and count the shifts.
exponent := uint16(0)
for offset := uint16(16); offset > 0; offset /= 2 {
// Right-shift the value until the highest bit is in position 11.
// For offset of 16, 8, 4, 2 and 1 (binary search over 1-30),
// shift if the bit is at or above 11 + offset.
if value >= (uint64(1) << (uFloat16MantissaBits + offset)) {
exponent += offset
value >>= offset
}
}
// Hidden bit (position 11) is set. We should remove it and increment the
// exponent. Equivalently, we just add it to the exponent.
// This hides the bit.
result = (uint16(value) + (exponent << uFloat16MantissaBits))
}
byteOrder.WriteUint16(b, result)
}

Some files were not shown because too many files have changed in this diff Show More