TUN-8456: Update quic-go to 0.45 and collect mtu and congestion control metrics

This commit is contained in:
chungthuang
2024-06-07 10:24:19 -05:00
committed by Chung-Ting Huang
parent cb6e5999e1
commit 0b62d45738
241 changed files with 27423 additions and 19798 deletions

View File

@@ -1,21 +1,29 @@
run:
skip-files:
- internal/handshake/cipher_suite.go
linters-settings:
misspell:
ignore-words:
- ect
depguard:
rules:
quicvarint:
list-mode: strict
files:
- "**/github.com/quic-go/quic-go/quicvarint/*"
- "!$test"
allow:
- $gostd
linters:
disable-all: true
enable:
- asciicheck
- depguard
- exhaustive
- exportloopref
- goimports
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- gosimple
- govet
- ineffassign
- misspell
- prealloc
@@ -24,10 +32,14 @@ linters:
- unconvert
- unparam
- unused
- vet
issues:
exclude-files:
- internal/handshake/cipher_suite.go
exclude-rules:
- path: internal/qtls
linters:
- depguard
- path: _test\.go
linters:
- exhaustive

View File

@@ -2,11 +2,12 @@
<img src="docs/quic.png" width=303 height=124>
[![Documentation](https://img.shields.io/badge/docs-quic--go.net-red?style=flat)](https://quic-go.net/docs/)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/quic-go/quic-go)](https://pkg.go.dev/github.com/quic-go/quic-go)
[![Code Coverage](https://img.shields.io/codecov/c/github/quic-go/quic-go/master.svg?style=flat-square)](https://codecov.io/gh/quic-go/quic-go/)
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/quic-go.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)).
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
@@ -16,207 +17,7 @@ In addition to these base RFCs, it also implements the following RFCs:
Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
## Using QUIC
### Running a Server
The central entry point is the `quic.Transport`. A transport manages QUIC connections running on a single UDP socket. Since QUIC uses Connection IDs, it can demultiplex a listener (accepting incoming connections) and an arbitrary number of outgoing QUIC connections on the same UDP socket.
```go
udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 1234})
// ... error handling
tr := quic.Transport{
Conn: udpConn,
}
ln, err := tr.Listen(tlsConf, quicConf)
// ... error handling
go func() {
for {
conn, err := ln.Accept()
// ... error handling
// handle the connection, usually in a new Go routine
}
}()
```
The listener `ln` can now be used to accept incoming QUIC connections by (repeatedly) calling the `Accept` method (see below for more information on the `quic.Connection`).
As a shortcut, `quic.Listen` and `quic.ListenAddr` can be used without explicitly initializing a `quic.Transport`:
```
ln, err := quic.Listen(udpConn, tlsConf, quicConf)
```
When using the shortcut, it's not possible to reuse the same UDP socket for outgoing connections.
### Running a Client
As mentioned above, multiple outgoing connections can share a single UDP socket, since QUIC uses Connection IDs to demultiplex connections.
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := tr.Dial(ctx, <server address>, <tls.Config>, <quic.Config>)
// ... error handling
```
As a shortcut, `quic.Dial` and `quic.DialAddr` can be used without explictly initializing a `quic.Transport`:
```go
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) // 3s handshake timeout
defer cancel()
conn, err := quic.Dial(ctx, conn, <server address>, <tls.Config>, <quic.Config>)
```
Just as we saw before when used a similar shortcut to run a server, it's also not possible to reuse the same UDP socket for other outgoing connections, or to listen for incoming connections.
### Using a QUIC Connection
#### Accepting Streams
QUIC is a stream-multiplexed transport. A `quic.Connection` fundamentally differs from the `net.Conn` and the `net.PacketConn` interface defined in the standard library. Data is sent and received on (unidirectional and bidirectional) streams (and, if supported, in [datagrams](#quic-datagrams)), not on the connection itself. The stream state machine is described in detail in [Section 3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-3).
Note: A unidirectional stream is a stream that the initiator can only write to (`quic.SendStream`), and the receiver can only read from (`quic.ReceiveStream`). A bidirectional stream (`quic.Stream`) allows reading from and writing to for both sides.
On the receiver side, streams are accepted using the `AcceptStream` (for bidirectional) and `AcceptUniStream` functions. For most user cases, it makes sense to call these functions in a loop:
```go
for {
str, err := conn.AcceptStream(context.Background()) // for bidirectional streams
// ... error handling
// handle the stream, usually in a new Go routine
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Opening Streams
There are two slightly different ways to open streams, one synchronous and one (potentially) asynchronous. This API is necessary since the receiver grants us a certain number of streams that we're allowed to open. It may grant us additional streams later on (typically when existing streams are closed), but it means that at the time we want to open a new stream, we might not be able to do so.
Using the synchronous method `OpenStreamSync` for bidirectional streams, and `OpenUniStreamSync` for unidirectional streams, an application can block until the peer allows opening additional streams. In case that we're allowed to open a new stream, these methods return right away:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
str, err := conn.OpenStreamSync(ctx) // wait up to 5s to open a new bidirectional stream
```
The asynchronous version never blocks. If it's currently not possible to open a new stream, it returns a `net.Error` timeout error:
```go
str, err := conn.OpenStream()
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
// It's currently not possible to open another stream,
// but it might be possible later, once the peer allowed us to do so.
}
```
These functions return an error when the underlying QUIC connection is closed.
#### Using Streams
Using QUIC streams is pretty straightforward. The `quic.ReceiveStream` implements the `io.Reader` interface, and the `quic.SendStream` implements the `io.Writer` interface. A bidirectional stream (`quic.Stream`) implements both these interfaces. Conceptually, a bidirectional stream can be thought of as the composition of two unidirectional streams in opposite directions.
Calling `Close` on a `quic.SendStream` or a `quic.Stream` closes the send side of the stream. On the receiver side, this will be surfaced as an `io.EOF` returned from the `io.Reader` once all data has been consumed. Note that for bidirectional streams, `Close` _only_ closes the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
In case the application wishes to abort sending on a `quic.SendStream` or a `quic.Stream` , it can reset the send side by calling `CancelWrite` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Reader`. Note that for bidirectional streams, `CancelWrite` _only_ resets the send side of the stream. It is still possible to read from the stream until the peer closes or resets the stream.
Conversely, in case the application wishes to abort receiving from a `quic.ReceiveStream` or a `quic.Stream`, it can ask the sender to abort data transmission by calling `CancelRead` with an application-defined error code (an unsigned 62-bit number). On the receiver side, this surfaced as a `quic.StreamError` containing that error code on the `io.Writer`. Note that for bidirectional streams, `CancelWrite` _only_ resets the receive side of the stream. It is still possible to write to the stream.
A bidirectional stream is only closed once both the read and the write side of the stream have been either closed or reset. Only then the peer is granted a new stream according to the maximum number of concurrent streams configured via `quic.Config.MaxIncomingStreams`.
### Configuring QUIC
The `quic.Config` struct passed to both the listen and dial calls (see above) contains a wide range of configuration options for QUIC connections, incl. the ability to fine-tune flow control limits, the number of streams that the peer is allowed to open concurrently, keep-alives, idle timeouts, and many more. Please refer to the documentation for the `quic.Config` for details.
The `quic.Transport` contains a few configuration options that don't apply to any single QUIC connection, but to all connections handled by that transport. It is highly recommend to set the `StatelessResetToken`, which allows endpoints to quickly recover from crashes / reboots of our node (see [Section 10.3 of RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000#section-10.3)).
### Closing a Connection
#### When the remote Peer closes the Connection
In case the peer closes the QUIC connection, all calls to open streams, accept streams, as well as all methods on streams immediately return an error. Additionally, it is set as cancellation cause of the connection context. Users can use errors assertions to find out what exactly went wrong:
* `quic.VersionNegotiationError`: Happens during the handshake, if there is no overlap between our and the remote's supported QUIC versions.
* `quic.HandshakeTimeoutError`: Happens if the QUIC handshake doesn't complete within the time specified in `quic.Config.HandshakeTimeout`.
* `quic.IdleTimeoutError`: Happens after completion of the handshake if the connection is idle for longer than the minimum of both peers idle timeouts (as configured by `quic.Config.IdleTimeout`). The connection is considered idle when no stream data (and datagrams, if applicable) are exchanged for that period. The QUIC connection can be instructed to regularly send a packet to prevent a connection from going idle by setting `quic.Config.KeepAlive`. However, this is no guarantee that the peer doesn't suddenly go away (e.g. by abruptly shutting down the node or by crashing), or by a NAT binding expiring, in which case this error might still occur.
* `quic.StatelessResetError`: Happens when the remote peer lost the state required to decrypt the packet. This requires the `quic.Transport.StatelessResetToken` to be configured by the peer.
* `quic.TransportError`: Happens if when the QUIC protocol is violated. Unless the error code is `APPLICATION_ERROR`, this will not happen unless one of the QUIC stacks involved is misbehaving. Please open an issue if you encounter this error.
* `quic.ApplicationError`: Happens when the remote decides to close the connection, see below.
#### Initiated by the Application
A `quic.Connection` can be closed using `CloseWithError`:
```go
conn.CloseWithError(0x42, "error 0x42 occurred")
```
Applications can transmit both an error code (an unsigned 62-bit number) as well as a UTF-8 encoded human-readable reason. The error code allows the receiver to learn why the connection was closed, and the reason can be useful for debugging purposes.
On the receiver side, this is surfaced as a `quic.ApplicationError`.
### QUIC Datagrams
Unreliable datagrams are a QUIC extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221)) that is negotiated during the handshake. Support can be enabled by setting the `quic.Config.EnableDatagram` flag. Note that this doesn't guarantee that the peer also supports datagrams. Whether or not the feature negotiation succeeded can be learned from the `quic.ConnectionState.SupportsDatagrams` obtained from `quic.Connection.ConnectionState()`.
QUIC DATAGRAMs are a new QUIC frame type sent in QUIC 1-RTT packets (i.e. after completion of the handshake). Therefore, they're end-to-end encrypted and congestion-controlled. However, if a DATAGRAM frame is deemed lost by QUIC's loss detection mechanism, they are not retransmitted.
Datagrams are sent using the `SendDatagram` method on the `quic.Connection`:
```go
conn.SendDatagram([]byte("foobar"))
```
And received using `ReceiveDatagram`:
```go
msg, err := conn.ReceiveDatagram()
```
Note that this code path is currently not optimized. It works for datagrams that are sent occasionally, but it doesn't achieve the same throughput as writing data on a stream. Please get in touch on issue #3766 if your use case relies on high datagram throughput, or if you'd like to help fix this issue. There are also some restrictions regarding the maximum message size (see #3599).
### QUIC Event Logging using qlog
quic-go logs a wide range of events defined in [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/), providing comprehensive insights in the internals of a QUIC connection.
qlog files can be processed by a number of 3rd-party tools. [qviz](https://qvis.quictools.info/) has proven very useful for debugging all kinds of QUIC connection failures.
qlog can be activated by setting the `Tracer` callback on the `Config`. It is called as soon as quic-go decides to start the QUIC handshake on a new connection.
`qlog.DefaultTracer` provides a tracer implementation which writes qlog files to a directory specified by the `QLOGDIR` environment variable, if set.
The default qlog tracer can be used like this:
```go
quic.Config{
Tracer: qlog.DefaultTracer,
}
```
This example creates a new qlog file under `<QLOGDIR>/<Original Destination Connection ID>_<Vantage Point>.qlog`, e.g. `qlogs/2e0407da_client.qlog`.
For custom qlog behavior, `qlog.NewConnectionTracer` can be used.
## Using HTTP/3
### As a server
See the [example server](example/main.go). Starting a QUIC server is very similar to the standard library http package in Go:
```go
http.Handle("/", http.FileServer(http.Dir(wwwDir)))
http3.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
```
### As a client
See the [example client](example/client/main.go). Use a `http3.RoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &http3.RoundTripper{},
}
```
Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
## Projects using quic-go

View File

@@ -35,7 +35,7 @@ type client struct {
conn quicConn
tracer *logging.ConnectionTracer
tracingID uint64
tracingID ConnectionTracingID
logger utils.Logger
}
@@ -191,6 +191,7 @@ func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.sendConn.LocalAddr(), c.sendConn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
c.conn = newClientConnection(
context.WithValue(context.WithoutCancel(ctx), ConnectionTracingKey, c.tracingID),
c.sendConn,
c.packetHandlers,
c.destConnID,
@@ -202,7 +203,6 @@ func (c *client) dial(ctx context.Context) error {
c.use0RTT,
c.hasNegotiatedVersion,
c.tracer,
c.tracingID,
c.logger,
c.version,
)

View File

@@ -39,6 +39,12 @@ func validateConfig(config *Config) error {
if config.MaxConnectionReceiveWindow > quicvarint.Max {
config.MaxConnectionReceiveWindow = quicvarint.Max
}
if config.InitialPacketSize > 0 && config.InitialPacketSize < protocol.MinInitialPacketSize {
config.InitialPacketSize = protocol.MinInitialPacketSize
}
if config.InitialPacketSize > protocol.MaxPacketBufferSize {
config.InitialPacketSize = protocol.MaxPacketBufferSize
}
// check that all QUIC versions are actually supported
for _, v := range config.Versions {
if !protocol.IsValidVersion(v) {
@@ -94,6 +100,10 @@ func populateConfig(config *Config) *Config {
} else if maxIncomingUniStreams < 0 {
maxIncomingUniStreams = 0
}
initialPacketSize := config.InitialPacketSize
if initialPacketSize == 0 {
initialPacketSize = protocol.InitialPacketSize
}
return &Config{
GetConfigForClient: config.GetConfigForClient,
@@ -110,6 +120,7 @@ func populateConfig(config *Config) *Config {
MaxIncomingUniStreams: maxIncomingUniStreams,
TokenStore: config.TokenStore,
EnableDatagrams: config.EnableDatagrams,
InitialPacketSize: initialPacketSize,
DisablePathMTUDiscovery: config.DisablePathMTUDiscovery,
Allow0RTT: config.Allow0RTT,
Tracer: config.Tracer,

View File

@@ -52,7 +52,7 @@ type streamManager interface {
}
type cryptoStreamHandler interface {
StartHandshake() error
StartHandshake(context.Context) error
ChangeConnectionID(protocol.ConnectionID)
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
@@ -113,8 +113,8 @@ func (e *errCloseForRecreating) Error() string {
return "closing connection in order to recreate it"
}
var connTracingID uint64 // to be accessed atomically
func nextConnTracingID() uint64 { return atomic.AddUint64(&connTracingID, 1) }
var connTracingID atomic.Uint64 // to be accessed atomically
func nextConnTracingID() ConnectionTracingID { return ConnectionTracingID(connTracingID.Add(1)) }
// A Connection is a QUIC connection
type connection struct {
@@ -153,7 +153,9 @@ type connection struct {
unpacker unpacker
frameParser wire.FrameParser
packer packer
mtuDiscoverer mtuDiscoverer // initialized when the handshake completes
mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received
maxPayloadSizeEstimate atomic.Uint32
initialStream cryptoStream
handshakeStream cryptoStream
@@ -167,10 +169,9 @@ type connection struct {
// closeChan is used to notify the run loop that it should terminate
closeChan chan closeError
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCtx context.Context
handshakeCtxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelCauseFunc
handshakeCompleteChan chan struct{}
undecryptablePackets []receivedPacket // undecryptable packets, waiting for a change in encryption level
undecryptablePacketsToProcess []receivedPacket
@@ -220,6 +221,8 @@ var (
)
var newConnection = func(
ctx context.Context,
ctxCancel context.CancelCauseFunc,
conn sendConn,
runner connRunner,
origDestConnID protocol.ConnectionID,
@@ -234,11 +237,12 @@ var newConnection = func(
tokenGenerator *handshake.TokenGenerator,
clientAddressValidated bool,
tracer *logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.Version,
) quicConn {
s := &connection{
ctx: ctx,
ctxCancel: ctxCancel,
conn: conn,
config: conf,
handshakeDestConnID: destConnID,
@@ -273,10 +277,9 @@ var newConnection = func(
connIDGenerator,
)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
0,
getMaxPacketSize(s.conn.RemoteAddr()),
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
clientAddressValidated,
s.conn.capabilities().ECN,
@@ -284,7 +287,7 @@ var newConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
params := &wire.TransportParameters{
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@@ -295,6 +298,7 @@ var newConnection = func(
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
AckDelayExponent: protocol.AckDelayExponent,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
DisableActiveMigration: true,
StatelessResetToken: &statelessResetToken,
OriginalDestinationConnectionID: origDestConnID,
@@ -336,6 +340,7 @@ var newConnection = func(
// declare this as a variable, such that we can it mock it in the tests
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
@@ -347,7 +352,6 @@ var newClientConnection = func(
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
tracingID uint64,
logger utils.Logger,
v protocol.Version,
) quicConn {
@@ -381,11 +385,11 @@ var newClientConnection = func(
s.queueControlFrame,
connIDGenerator,
)
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
getMaxPacketSize(s.conn.RemoteAddr()),
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
false, // has no effect
s.conn.capabilities().ECN,
@@ -393,7 +397,7 @@ var newClientConnection = func(
s.tracer,
s.logger,
)
s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize)
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
oneRTTStream := newCryptoStream()
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@@ -404,6 +408,7 @@ var newClientConnection = func(
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
@@ -471,6 +476,7 @@ func (s *connection) preSetup() {
)
s.earlyConnReadyChan = make(chan struct{})
s.streamsMap = newStreamsMap(
s.ctx,
s,
s.newFlowController,
uint64(s.config.MaxIncomingStreams),
@@ -481,7 +487,7 @@ func (s *connection) preSetup() {
s.receivedPackets = make(chan receivedPacket, protocol.MaxConnUnprocessedPackets)
s.closeChan = make(chan closeError, 1)
s.sendingScheduled = make(chan struct{}, 1)
s.handshakeCtx, s.handshakeCtxCancel = context.WithCancel(context.Background())
s.handshakeCompleteChan = make(chan struct{})
now := time.Now()
s.lastPacketReceivedTime = now
@@ -495,13 +501,11 @@ func (s *connection) preSetup() {
// run the connection main loop
func (s *connection) run() error {
var closeErr closeError
defer func() {
s.ctxCancel(closeErr.err)
}()
defer func() { s.ctxCancel(closeErr.err) }()
s.timer = *newTimer()
if err := s.cryptoStreamHandler.StartHandshake(); err != nil {
if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil {
return err
}
if err := s.handleHandshakeEvents(); err != nil {
@@ -662,7 +666,7 @@ func (s *connection) earlyConnReady() <-chan struct{} {
}
func (s *connection) HandshakeComplete() <-chan struct{} {
return s.handshakeCtx.Done()
return s.handshakeCompleteChan
}
func (s *connection) Context() context.Context {
@@ -727,7 +731,7 @@ func (s *connection) idleTimeoutStartTime() time.Time {
}
func (s *connection) handleHandshakeComplete() error {
defer s.handshakeCtxCancel()
defer close(s.handshakeCompleteChan)
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption anymore.
s.undecryptablePackets = nil
@@ -780,11 +784,7 @@ func (s *connection) handleHandshakeConfirmed() error {
s.cryptoStreamHandler.SetHandshakeConfirmed()
if !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
s.mtuDiscoverer.Start(min(maxPacketSize, protocol.MaxPacketBufferSize))
s.mtuDiscoverer.Start()
}
return nil
}
@@ -1773,6 +1773,17 @@ func (s *connection) applyTransportParameters() {
// Retire the connection ID.
s.connIDManager.AddFromPreferredAddress(params.PreferredAddress.ConnectionID, params.PreferredAddress.StatelessResetToken)
}
maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
if params.MaxUDPPayloadSize > 0 && params.MaxUDPPayloadSize < maxPacketSize {
maxPacketSize = params.MaxUDPPayloadSize
}
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
protocol.ByteCount(s.config.InitialPacketSize),
maxPacketSize,
s.onMTUIncreased,
s.tracer,
)
}
func (s *connection) triggerSending(now time.Time) error {
@@ -1861,7 +1872,7 @@ func (s *connection) sendPackets(now time.Time) error {
}
if !s.handshakeConfirmed {
packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), s.version)
if err != nil || packet == nil {
return err
}
@@ -1888,7 +1899,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
for {
buf := getPacketBuffer()
ecn := s.sentPacketHandler.ECNMode(true)
if _, err := s.appendOneShortHeaderPacket(buf, s.mtuDiscoverer.CurrentSize(), ecn, now); err != nil {
if _, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now); err != nil {
if err == errNothingToPack {
buf.Release()
return nil
@@ -1919,7 +1930,7 @@ func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
func (s *connection) sendPacketsWithGSO(now time.Time) error {
buf := getLargePacketBuffer()
maxSize := s.mtuDiscoverer.CurrentSize()
maxSize := s.maxPacketSize()
ecn := s.sentPacketHandler.ECNMode(true)
for {
@@ -1988,7 +1999,7 @@ func (s *connection) resetPacingDeadline() {
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
if !s.handshakeConfirmed {
ecn := s.sentPacketHandler.ECNMode(false)
packet, err := s.packer.PackCoalescedPacket(true, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@@ -1999,7 +2010,7 @@ func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
}
ecn := s.sentPacketHandler.ECNMode(true)
p, buf, err := s.packer.PackAckOnlyPacket(s.mtuDiscoverer.CurrentSize(), s.version)
p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), s.version)
if err != nil {
if err == errNothingToPack {
return nil
@@ -2021,7 +2032,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
break
}
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@@ -2032,7 +2043,7 @@ func (s *connection) sendProbePacket(encLevel protocol.EncryptionLevel, now time
if packet == nil {
s.retransmissionQueue.AddPing(encLevel)
var err error
packet, err = s.packer.MaybePackProbePacket(encLevel, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.MaybePackProbePacket(encLevel, s.maxPacketSize(), s.version)
if err != nil {
return err
}
@@ -2111,14 +2122,14 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
var transportErr *qerr.TransportError
var applicationErr *qerr.ApplicationError
if errors.As(e, &transportErr) {
packet, err = s.packer.PackConnectionClose(transportErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
} else if errors.As(e, &applicationErr) {
packet, err = s.packer.PackApplicationClose(applicationErr, s.mtuDiscoverer.CurrentSize(), s.version)
packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
} else {
packet, err = s.packer.PackConnectionClose(&qerr.TransportError{
ErrorCode: qerr.InternalError,
ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
}, s.mtuDiscoverer.CurrentSize(), s.version)
}, s.maxPacketSize(), s.version)
}
if err != nil {
return nil, err
@@ -2128,6 +2139,24 @@ func (s *connection) sendConnectionClose(e error) ([]byte, error) {
return packet.buffer.Data, s.conn.Write(packet.buffer.Data, 0, ecn)
}
func (s *connection) maxPacketSize() protocol.ByteCount {
if s.mtuDiscoverer == nil {
// Use the configured packet size on the client side.
// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
// Apparently the server still processed the (fully padded) Initial packet anyway.
if s.perspective == protocol.PerspectiveClient {
return protocol.ByteCount(s.config.InitialPacketSize)
}
// On the server side, there's no downside to using 1200 bytes until we received the client's transport
// parameters:
// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
// need a lot of bytes for that.
// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
return protocol.MinInitialPacketSize
}
return s.mtuDiscoverer.CurrentSize()
}
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
// quic-go logging
if s.logger.Debug() {
@@ -2351,16 +2380,25 @@ func (s *connection) onStreamCompleted(id protocol.StreamID) {
}
}
func (s *connection) onMTUIncreased(mtu protocol.ByteCount) {
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(mtu)))
s.sentPacketHandler.SetMaxDatagramSize(mtu)
}
func (s *connection) SendDatagram(p []byte) error {
if !s.supportsDatagrams() {
return errors.New("datagram support disabled")
}
f := &wire.DatagramFrame{DataLenPresent: true}
if protocol.ByteCount(len(p)) > f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version) {
return &DatagramTooLargeError{
PeerMaxDatagramFrameSize: int64(s.peerParams.MaxDatagramFrameSize),
}
// The payload size estimate is conservative.
// Under many circumstances we could send a few more bytes.
maxDataLen := min(
f.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version),
protocol.ByteCount(s.maxPayloadSizeEstimate.Load()),
)
if protocol.ByteCount(len(p)) > maxDataLen {
return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
}
f.Data = make([]byte, len(p))
copy(f.Data, p)
@@ -2386,8 +2424,22 @@ func (s *connection) GetVersion() protocol.Version {
return s.version
}
func (s *connection) NextConnection() Connection {
<-s.HandshakeComplete()
s.streamsMap.UseResetMaps()
return s
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
// The handshake might fail after the server rejected 0-RTT.
// This could happen if the Finished message is malformed or never received.
select {
case <-ctx.Done():
return nil, context.Cause(ctx)
case <-s.Context().Done():
case <-s.HandshakeComplete():
s.streamsMap.UseResetMaps()
}
return s, nil
}
// estimateMaxPayloadSize estimates the maximum payload size for short header packets.
// It is not very sophisticated: it just subtracts the size of header (assuming the maximum
// connection ID length), and the size of the encryption tag.
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
return mtu - 1 /* type byte */ - 20 /* maximum connection ID length */ - 16 /* tag size */
}

View File

@@ -64,7 +64,7 @@ func (e *StreamError) Error() string {
// DatagramTooLargeError is returned from Connection.SendDatagram if the payload is too large to be sent.
type DatagramTooLargeError struct {
PeerMaxDatagramFrameSize int64
MaxDatagramPayloadSize int64
}
func (e *DatagramTooLargeError) Is(target error) bool {

View File

@@ -157,7 +157,7 @@ func (f *framerI) AppendStreamFrames(frames []ackhandler.StreamFrame, maxLen pro
// For the last STREAM frame, we'll remove the DataLen field later.
// Therefore, we can pretend to have more bytes available when popping
// the STREAM frame (which will always have the DataLen set).
remainingLen += quicvarint.Len(uint64(remainingLen))
remainingLen += protocol.ByteCount(quicvarint.Len(uint64(remainingLen)))
frame, ok, hasMoreData := str.popStreamFrame(remainingLen, v)
if hasMoreData { // put the stream back in the queue (at the end)
f.streamQueue.PushBack(id)

View File

@@ -57,8 +57,13 @@ var Err0RTTRejected = errors.New("0-RTT rejected")
// ConnectionTracingKey can be used to associate a ConnectionTracer with a Connection.
// It is set on the Connection.Context() context,
// as well as on the context passed to logging.Tracer.NewConnectionTracer.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
var ConnectionTracingKey = connTracingCtxKey{}
// ConnectionTracingID is the type of the context value saved under the ConnectionTracingKey.
// Deprecated: Applications can set their own tracing key using Transport.ConnContext.
type ConnectionTracingID uint64
type connTracingCtxKey struct{}
// QUICVersionContextKey can be used to find out the QUIC version of a TLS handshake from the
@@ -121,7 +126,9 @@ type SendStream interface {
// CancelWrite aborts sending on this stream.
// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
// Write will unblock immediately, and future calls to Write will fail.
// When called multiple times or after closing the stream it is a no-op.
// When called multiple times it is a no-op.
// When called after Close, it aborts delivery. Note that there is no guarantee if
// the peer will receive the FIN or the reset first.
CancelWrite(StreamErrorCode)
// The Context is canceled as soon as the write-side of the stream is closed.
// This happens when Close() or CancelWrite() is called, or when the peer
@@ -217,7 +224,7 @@ type EarlyConnection interface {
// however the client's identity is only verified once the handshake completes.
HandshakeComplete() <-chan struct{}
NextConnection() Connection
NextConnection(context.Context) (Connection, error)
}
// StatelessResetKey is a key used to derive stateless reset tokens.
@@ -320,10 +327,15 @@ type Config struct {
// If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most
// every half of MaxIdleTimeout, whichever is smaller).
KeepAlivePeriod time.Duration
// InitialPacketSize is the initial size of packets sent.
// It is usually not necessary to manually set this value,
// since Path MTU discovery very quickly finds the path's MTU.
// If set too high, the path might not support packets that large, leading to a timeout of the QUIC handshake.
// Values below 1200 are invalid.
InitialPacketSize uint16
// DisablePathMTUDiscovery disables Path MTU Discovery (RFC 8899).
// This allows the sending of QUIC packets that fully utilize the available MTU of the path.
// Path MTU discovery is only available on systems that allow setting of the Don't Fragment (DF) bit.
// If unavailable or disabled, packets will be at most 1252 (IPv4) / 1232 (IPv6) bytes in size.
DisablePathMTUDiscovery bool
// Allow0RTT allows the application to decide if a 0-RTT connection attempt should be accepted.
// Only valid for the server.

View File

@@ -17,11 +17,11 @@ import (
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
)
const defaultNumConnections = 1

View File

@@ -12,7 +12,7 @@ import (
const (
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSize)
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2

View File

@@ -111,6 +111,7 @@ func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
func (c *streamFlowController) Abandon() {
c.mutex.Lock()
unread := c.highestReceived - c.bytesRead
c.bytesRead = c.highestReceived
c.mutex.Unlock()
if unread > 0 {
c.connection.AddBytesRead(unread)

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"context"
"crypto/tls"
"errors"
@@ -124,44 +123,12 @@ func NewCryptoSetupServer(
)
cs.allow0RTT = allow0RTT
quicConf := &tls.QUICConfig{TLSConfig: tlsConf}
qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.handleSessionTicket)
addConnToClientHelloInfo(quicConf.TLSConfig, localAddr, remoteAddr)
cs.tlsConf = quicConf.TLSConfig
cs.conn = tls.QUICServer(quicConf)
tlsConf = qtls.SetupConfigForServer(tlsConf, localAddr, remoteAddr, cs.getDataForSessionTicket, cs.handleSessionTicket)
cs.tlsConf = tlsConf
cs.conn = tls.QUICServer(&tls.QUICConfig{TLSConfig: tlsConf})
return cs
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
func addConnToClientHelloInfo(conf *tls.Config, localAddr, remoteAddr net.Addr) {
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
c = c.Clone()
// This won't be necessary anymore once https://github.com/golang/go/issues/63722 is accepted.
c.MinVersion = tls.VersionTLS13
// We're returning a tls.Config here, so we need to apply this recursively.
addConnToClientHelloInfo(c, localAddr, remoteAddr)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
}
func newCryptoSetup(
connID protocol.ConnectionID,
tp *wire.TransportParameters,
@@ -204,8 +171,8 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) StartHandshake() error {
err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version))
func (h *cryptoSetup) StartHandshake(ctx context.Context) error {
err := h.conn.Start(context.WithValue(ctx, QUICVersionContextKey, h.version))
if err != nil {
return wrapError(err)
}
@@ -338,25 +305,26 @@ func (h *cryptoSetup) handleDataFromSessionState(data []byte, earlyData bool) (a
return false
}
func decodeDataFromSessionState(data []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
func decodeDataFromSessionState(b []byte, earlyData bool) (time.Duration, *wire.TransportParameters, error) {
ver, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
if ver != clientSessionStateRevision {
return 0, nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rttEncoded, err := quicvarint.Read(r)
rttEncoded, l, err := quicvarint.Parse(b)
if err != nil {
return 0, nil, err
}
b = b[l:]
rtt := time.Duration(rttEncoded) * time.Microsecond
if !earlyData {
return rtt, nil, nil
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return 0, nil, err
}
return rtt, &tp, nil
@@ -376,9 +344,7 @@ func (h *cryptoSetup) getDataForSessionTicket() []byte {
// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection.
// It is only valid for the server.
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{
EarlyData: h.allow0RTT,
}); err != nil {
if err := h.conn.SendSessionTicket(tls.QUICSessionTicketOptions{EarlyData: h.allow0RTT}); err != nil {
// Session tickets might be disabled by tls.Config.SessionTicketsDisabled.
// We can't check h.tlsConfig here, since the actual config might have been obtained from
// the GetConfigForClient callback.

View File

@@ -1,6 +1,7 @@
package handshake
import (
"context"
"crypto/tls"
"errors"
"io"
@@ -91,7 +92,7 @@ type Event struct {
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
StartHandshake() error
StartHandshake(context.Context) error
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)

View File

@@ -1,7 +1,6 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
@@ -28,25 +27,26 @@ func (t *sessionTicket) Marshal() []byte {
}
func (t *sessionTicket) Unmarshal(b []byte, using0RTT bool) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
rev, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read session ticket revision")
}
b = b[l:]
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
rtt, l, err := quicvarint.Parse(b)
if err != nil {
return errors.New("failed to read RTT")
}
b = b[l:]
if using0RTT {
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
if err := tp.UnmarshalFromSessionTicket(b); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
} else if r.Len() > 0 {
} else if len(b) > 0 {
return fmt.Errorf("the session ticket has more bytes than expected")
}
t.RTT = time.Duration(rtt) * time.Microsecond

View File

@@ -3,16 +3,13 @@ package protocol
import "time"
// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
const DesiredReceiveBufferSize = (1 << 20) * 7 // 7 MB
// DesiredSendBufferSize is the kernel UDP send buffer size that we'd like to use.
const DesiredSendBufferSize = (1 << 20) * 2 // 2 MB
const DesiredSendBufferSize = (1 << 20) * 7 // 7 MB
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
const InitialPacketSizeIPv4 = 1252
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const InitialPacketSizeIPv6 = 1232
// InitialPacketSize is the initial (before Path MTU discovery) maximum packet size used.
const InitialPacketSize = 1280
// MaxCongestionWindowPackets is the maximum congestion window in packet.
const MaxCongestionWindowPackets = 10000

View File

@@ -1,4 +1,4 @@
package handshake
package qtls
import (
"net"

View File

@@ -4,20 +4,23 @@ import (
"bytes"
"crypto/tls"
"fmt"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte, handleSessionTicket func([]byte, bool) bool) {
conf := qconf.TLSConfig
func SetupConfigForServer(
conf *tls.Config,
localAddr, remoteAddr net.Addr,
getData func() []byte,
handleSessionTicket func([]byte, bool) bool,
) *tls.Config {
// Workaround for https://github.com/golang/go/issues/60506.
// This initializes the session tickets _before_ cloning the config.
_, _ = conf.DecryptTicket(nil, tls.ConnectionState{})
conf = conf.Clone()
conf.MinVersion = tls.VersionTLS13
qconf.TLSConfig = conf
// add callbacks to save transport parameters into the session ticket
origWrapSession := conf.WrapSession
@@ -58,6 +61,29 @@ func SetupConfigForServer(qconf *tls.QUICConfig, _ bool, getData func() []byte,
return state, nil
}
// The tls.Config contains two callbacks that pass in a tls.ClientHelloInfo.
// Since crypto/tls doesn't do it, we need to make sure to set the Conn field with a fake net.Conn
// that allows the caller to get the local and the remote address.
if conf.GetConfigForClient != nil {
gcfc := conf.GetConfigForClient
conf.GetConfigForClient = func(info *tls.ClientHelloInfo) (*tls.Config, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
c, err := gcfc(info)
if c != nil {
// We're returning a tls.Config here, so we need to apply this recursively.
c = SetupConfigForServer(c, localAddr, remoteAddr, getData, handleSessionTicket)
}
return c, err
}
}
if conf.GetCertificate != nil {
gc := conf.GetCertificate
conf.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
info.Conn = &conn{localAddr: localAddr, remoteAddr: remoteAddr}
return gc(info)
}
}
return conf
}
func SetupConfigForClient(

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"sort"
"time"
@@ -22,18 +21,21 @@ type AckFrame struct {
}
// parseAckFrame reads an ACK frame
func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponent uint8, _ protocol.Version) error {
func parseAckFrame(frame *AckFrame, b []byte, typ uint64, ackDelayExponent uint8, _ protocol.Version) (int, error) {
startLen := len(b)
ecn := typ == ackECNFrameType
la, err := quicvarint.Read(r)
la, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
largestAcked := protocol.PacketNumber(la)
delay, err := quicvarint.Read(r)
delay, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
@@ -42,71 +44,78 @@ func parseAckFrame(frame *AckFrame, r *bytes.Reader, typ uint64, ackDelayExponen
}
frame.DelayTime = delayTime
numBlocks, err := quicvarint.Read(r)
numBlocks, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
// read the first ACK range
ab, err := quicvarint.Read(r)
ab, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked {
return errors.New("invalid first ACK range")
return 0, errors.New("invalid first ACK range")
}
smallest := largestAcked - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
// read all the other ACK ranges
for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r)
g, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
gap := protocol.PacketNumber(g)
if smallest < gap+2 {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
largest := smallest - gap - 2
ab, err := quicvarint.Read(r)
ab, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
}
if !frame.validateAckRanges() {
return errInvalidAckRanges
return 0, errInvalidAckRanges
}
if ecn {
ect0, err := quicvarint.Read(r)
ect0, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT0 = ect0
ect1, err := quicvarint.Read(r)
ect1, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECT1 = ect1
ecnce, err := quicvarint.Read(r)
ecnce, l, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.ECNCE = ecnce
}
return nil
return startLen - len(b), nil
}
// Append appends an ACK frame.
@@ -163,7 +172,7 @@ func (f *AckFrame) Length(_ protocol.Version) protocol.ByteCount {
length += quicvarint.Len(f.ECT1)
length += quicvarint.Len(f.ECNCE)
}
return length
return protocol.ByteCount(length)
}
// gets the number of ACK ranges that can be encoded
@@ -174,7 +183,7 @@ func (f *AckFrame) numEncodableAckRanges() int {
for i := 1; i < len(f.AckRanges); i++ {
gap, len := f.encodeAckRange(i)
rangeLen := quicvarint.Len(gap) + quicvarint.Len(len)
if length+rangeLen > protocol.MaxAckFrameSize {
if protocol.ByteCount(length+rangeLen) > protocol.MaxAckFrameSize {
// Writing range i would exceed the MaxAckFrameSize.
// So encode one range less than that.
return i - 1

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -16,47 +15,45 @@ type ConnectionCloseFrame struct {
ReasonPhrase string
}
func parseConnectionCloseFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, error) {
func parseConnectionCloseFrame(b []byte, typ uint64, _ protocol.Version) (*ConnectionCloseFrame, int, error) {
startLen := len(b)
f := &ConnectionCloseFrame{IsApplicationError: typ == applicationCloseFrameType}
ec, err := quicvarint.Read(r)
ec, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.ErrorCode = ec
// read the Frame Type, if this is not an application error
if !f.IsApplicationError {
ft, err := quicvarint.Read(r)
ft, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
f.FrameType = ft
}
var reasonPhraseLen uint64
reasonPhraseLen, err = quicvarint.Read(r)
reasonPhraseLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the whole reason phrase would result in EOF when attempting to READ
if int(reasonPhraseLen) > r.Len() {
return nil, io.EOF
b = b[l:]
if int(reasonPhraseLen) > len(b) {
return nil, 0, io.EOF
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
// this should never happen, since we already checked the reasonPhraseLen earlier
return nil, err
}
copy(reasonPhrase, b)
f.ReasonPhrase = string(reasonPhrase)
return f, nil
return f, startLen - len(b) + int(reasonPhraseLen), nil
}
// Length of a written frame
func (f *ConnectionCloseFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
length := 1 + protocol.ByteCount(quicvarint.Len(f.ErrorCode)+quicvarint.Len(uint64(len(f.ReasonPhrase)))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError {
length += quicvarint.Len(f.FrameType) // for the frame type
length += protocol.ByteCount(quicvarint.Len(f.FrameType)) // for the frame type
}
return length
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -14,28 +13,28 @@ type CryptoFrame struct {
Data []byte
}
func parseCryptoFrame(r *bytes.Reader, _ protocol.Version) (*CryptoFrame, error) {
func parseCryptoFrame(b []byte, _ protocol.Version) (*CryptoFrame, int, error) {
startLen := len(b)
frame := &CryptoFrame{}
offset, err := quicvarint.Read(r)
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
frame.Offset = protocol.ByteCount(offset)
dataLen, err := quicvarint.Read(r)
dataLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
if dataLen > uint64(r.Len()) {
return nil, io.EOF
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err
}
copy(frame.Data, b)
}
return frame, nil
return frame, startLen - len(b) + int(dataLen), nil
}
func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -48,14 +47,14 @@ func (f *CryptoFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
// Length of a written frame
func (f *CryptoFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
return protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + len(f.Data))
}
// MaxDataLen returns the maximum data length
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1
headerLen := protocol.ByteCount(1 + quicvarint.Len(uint64(f.Offset)) + 1)
if headerLen > maxSize {
return 0
}

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -12,12 +10,12 @@ type DataBlockedFrame struct {
MaximumData protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*DataBlockedFrame, error) {
offset, err := quicvarint.Read(r)
func parseDataBlockedFrame(b []byte, _ protocol.Version) (*DataBlockedFrame, int, error) {
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, nil
return &DataBlockedFrame{MaximumData: protocol.ByteCount(offset)}, l, nil
}
func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
@@ -27,5 +25,5 @@ func (f *DataBlockedFrame) Append(b []byte, version protocol.Version) ([]byte, e
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData)))
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -20,29 +19,29 @@ type DatagramFrame struct {
Data []byte
}
func parseDatagramFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*DatagramFrame, error) {
func parseDatagramFrame(b []byte, typ uint64, _ protocol.Version) (*DatagramFrame, int, error) {
startLen := len(b)
f := &DatagramFrame{}
f.DataLenPresent = typ&0x1 > 0
var length uint64
if f.DataLenPresent {
var err error
len, err := quicvarint.Read(r)
var l int
length, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
if len > uint64(r.Len()) {
return nil, io.EOF
b = b[l:]
if length > uint64(len(b)) {
return nil, 0, io.EOF
}
length = len
} else {
length = uint64(r.Len())
length = uint64(len(b))
}
f.Data = make([]byte, length)
if _, err := io.ReadFull(r, f.Data); err != nil {
return nil, err
}
return f, nil
copy(f.Data, b)
return f, startLen - len(b) + int(length), nil
}
func (f *DatagramFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -80,7 +79,7 @@ func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.
func (f *DatagramFrame) Length(_ protocol.Version) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent {
length += quicvarint.Len(uint64(len(f.Data)))
length += protocol.ByteCount(quicvarint.Len(uint64(len(f.Data))))
}
return length
}

View File

@@ -165,7 +165,7 @@ func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
length += protocol.ByteCount(quicvarint.Len(uint64(len(h.Token))) + len(h.Token))
}
return length
}

View File

@@ -1,9 +1,9 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
"github.com/quic-go/quic-go/internal/protocol"
@@ -38,8 +38,6 @@ const (
// The FrameParser parses QUIC frames, one by one.
type FrameParser struct {
r bytes.Reader // cached bytes.Reader, so we don't have to repeatedly allocate them
ackDelayExponent uint8
supportsDatagrams bool
@@ -51,7 +49,6 @@ type FrameParser struct {
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool) *FrameParser {
return &FrameParser{
r: *bytes.NewReader(nil),
supportsDatagrams: supportsDatagrams,
ackFrame: &AckFrame{},
}
@@ -60,45 +57,46 @@ func NewFrameParser(supportsDatagrams bool) *FrameParser {
// ParseNext parses the next frame.
// It skips PADDING frames.
func (p *FrameParser) ParseNext(data []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (int, Frame, error) {
startLen := len(data)
p.r.Reset(data)
frame, err := p.parseNext(&p.r, encLevel, v)
n := startLen - p.r.Len()
p.r.Reset(nil)
return n, frame, err
frame, l, err := p.parseNext(data, encLevel, v)
return l, frame, err
}
func (p *FrameParser) parseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
for r.Len() != 0 {
typ, err := quicvarint.Read(r)
func (p *FrameParser) parseNext(b []byte, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var parsed int
for len(b) != 0 {
typ, l, err := quicvarint.Parse(b)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
b = b[l:]
if typ == 0x0 { // skip PADDING frames
continue
}
f, err := p.parseFrame(r, typ, encLevel, v)
f, l, err := p.parseFrame(b, typ, encLevel, v)
parsed += l
if err != nil {
return nil, &qerr.TransportError{
return nil, parsed, &qerr.TransportError{
FrameType: typ,
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
return f, nil
return f, parsed, nil
}
return nil, nil
return nil, parsed, nil
}
func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, error) {
func (p *FrameParser) parseFrame(b []byte, typ uint64, encLevel protocol.EncryptionLevel, v protocol.Version) (Frame, int, error) {
var frame Frame
var err error
var l int
if typ&0xf8 == 0x8 {
frame, err = parseStreamFrame(r, typ, v)
frame, l, err = parseStreamFrame(b, typ, v)
} else {
switch typ {
case pingFrameType:
@@ -109,43 +107,43 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
ackDelayExponent = protocol.DefaultAckDelayExponent
}
p.ackFrame.Reset()
err = parseAckFrame(p.ackFrame, r, typ, ackDelayExponent, v)
l, err = parseAckFrame(p.ackFrame, b, typ, ackDelayExponent, v)
frame = p.ackFrame
case resetStreamFrameType:
frame, err = parseResetStreamFrame(r, v)
frame, l, err = parseResetStreamFrame(b, v)
case stopSendingFrameType:
frame, err = parseStopSendingFrame(r, v)
frame, l, err = parseStopSendingFrame(b, v)
case cryptoFrameType:
frame, err = parseCryptoFrame(r, v)
frame, l, err = parseCryptoFrame(b, v)
case newTokenFrameType:
frame, err = parseNewTokenFrame(r, v)
frame, l, err = parseNewTokenFrame(b, v)
case maxDataFrameType:
frame, err = parseMaxDataFrame(r, v)
frame, l, err = parseMaxDataFrame(b, v)
case maxStreamDataFrameType:
frame, err = parseMaxStreamDataFrame(r, v)
frame, l, err = parseMaxStreamDataFrame(b, v)
case bidiMaxStreamsFrameType, uniMaxStreamsFrameType:
frame, err = parseMaxStreamsFrame(r, typ, v)
frame, l, err = parseMaxStreamsFrame(b, typ, v)
case dataBlockedFrameType:
frame, err = parseDataBlockedFrame(r, v)
frame, l, err = parseDataBlockedFrame(b, v)
case streamDataBlockedFrameType:
frame, err = parseStreamDataBlockedFrame(r, v)
frame, l, err = parseStreamDataBlockedFrame(b, v)
case bidiStreamBlockedFrameType, uniStreamBlockedFrameType:
frame, err = parseStreamsBlockedFrame(r, typ, v)
frame, l, err = parseStreamsBlockedFrame(b, typ, v)
case newConnectionIDFrameType:
frame, err = parseNewConnectionIDFrame(r, v)
frame, l, err = parseNewConnectionIDFrame(b, v)
case retireConnectionIDFrameType:
frame, err = parseRetireConnectionIDFrame(r, v)
frame, l, err = parseRetireConnectionIDFrame(b, v)
case pathChallengeFrameType:
frame, err = parsePathChallengeFrame(r, v)
frame, l, err = parsePathChallengeFrame(b, v)
case pathResponseFrameType:
frame, err = parsePathResponseFrame(r, v)
frame, l, err = parsePathResponseFrame(b, v)
case connectionCloseFrameType, applicationCloseFrameType:
frame, err = parseConnectionCloseFrame(r, typ, v)
frame, l, err = parseConnectionCloseFrame(b, typ, v)
case handshakeDoneFrameType:
frame = &HandshakeDoneFrame{}
case 0x30, 0x31:
if p.supportsDatagrams {
frame, err = parseDatagramFrame(r, typ, v)
frame, l, err = parseDatagramFrame(b, typ, v)
break
}
fallthrough
@@ -154,12 +152,12 @@ func (p *FrameParser) parseFrame(r *bytes.Reader, typ uint64, encLevel protocol.
}
}
if err != nil {
return nil, err
return nil, 0, err
}
if !p.isAllowedAtEncLevel(frame, encLevel) {
return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
return nil, l, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
}
return frame, nil
return frame, l, nil
}
func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
@@ -190,3 +188,10 @@ func (p *FrameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionL
func (p *FrameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}
func replaceUnexpectedEOF(e error) error {
if e == io.ErrUnexpectedEOF {
return io.EOF
}
return e
}

View File

@@ -8,7 +8,6 @@ import (
"io"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -139,18 +138,18 @@ type Header struct {
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses a packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// ParsePacket parses a long header packet.
// The packet is cut according to the length field.
// If we understand the version, the packet is parsed up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
if len(data) == 0 || !IsLongHeaderPacket(data[0]) {
return nil, nil, nil, errors.New("not a long header packet")
}
hdr, err := parseHeader(bytes.NewReader(data))
hdr, err := parseHeader(data)
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
if errors.Is(err, ErrUnsupportedVersion) {
return hdr, nil, nil, err
}
return nil, nil, nil, err
}
@@ -161,55 +160,55 @@ func ParsePacket(data []byte) (*Header, []byte, []byte, error) {
return hdr, data[:packetLen], data[packetLen:], nil
}
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// ParseHeader parses the header:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader) (*Header, error) {
startLen := b.Len()
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
func parseHeader(b []byte) (*Header, error) {
if len(b) == 0 {
return nil, io.EOF
}
typeByte := b[0]
h := &Header{typeByte: typeByte}
err = h.parseLongHeader(b)
h.parsedLen = protocol.ByteCount(startLen - b.Len())
l, err := h.parseLongHeader(b[1:])
h.parsedLen = protocol.ByteCount(l) + 1
return h, err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
func (h *Header) parseLongHeader(b []byte) (int, error) {
startLen := len(b)
if len(b) < 5 {
return 0, io.EOF
}
h.Version = protocol.Version(v)
h.Version = protocol.Version(binary.BigEndian.Uint32(b[:4]))
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
return startLen - len(b), errors.New("not a QUIC packet")
}
destConnIDLen, err := b.ReadByte()
if err != nil {
return err
destConnIDLen := int(b[4])
if destConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
if err != nil {
return err
b = b[5:]
if len(b) < destConnIDLen+1 {
return startLen - len(b), io.EOF
}
srcConnIDLen, err := b.ReadByte()
if err != nil {
return err
h.DestConnectionID = protocol.ParseConnectionID(b[:destConnIDLen])
srcConnIDLen := int(b[destConnIDLen])
if srcConnIDLen > protocol.MaxConnIDLen {
return startLen - len(b), protocol.ErrInvalidConnectionIDLen
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
if err != nil {
return err
b = b[destConnIDLen+1:]
if len(b) < srcConnIDLen {
return startLen - len(b), io.EOF
}
h.SrcConnectionID = protocol.ParseConnectionID(b[:srcConnIDLen])
b = b[srcConnIDLen:]
if h.Version == 0 { // version negotiation packet
return nil
return startLen - len(b), nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return ErrUnsupportedVersion
return startLen - len(b), ErrUnsupportedVersion
}
if h.Version == protocol.Version2 {
@@ -237,38 +236,35 @@ func (h *Header) parseLongHeader(b *bytes.Reader) error {
}
if h.Type == protocol.PacketTypeRetry {
tokenLen := b.Len() - 16
tokenLen := len(b) - 16
if tokenLen <= 0 {
return io.EOF
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
_, err := b.Seek(16, io.SeekCurrent)
return err
copy(h.Token, b[:tokenLen])
return startLen - len(b) + tokenLen + 16, nil
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := quicvarint.Read(b)
tokenLen, n, err := quicvarint.Parse(b)
if err != nil {
return err
return startLen - len(b), err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
b = b[n:]
if tokenLen > uint64(len(b)) {
return startLen - len(b), io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
copy(h.Token, b[:tokenLen])
b = b[tokenLen:]
}
pl, err := quicvarint.Read(b)
pl, n, err := quicvarint.Parse(b)
if err != nil {
return err
return 0, err
}
h.Length = protocol.ByteCount(pl)
return nil
return startLen - len(b) + n, nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -13,14 +11,14 @@ type MaxDataFrame struct {
}
// parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxDataFrame, error) {
func parseMaxDataFrame(b []byte, _ protocol.Version) (*MaxDataFrame, int, error) {
frame := &MaxDataFrame{}
byteOffset, err := quicvarint.Read(r)
byteOffset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
frame.MaximumData = protocol.ByteCount(byteOffset)
return frame, nil
return frame, l, nil
}
func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -31,5 +29,5 @@ func (f *MaxDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
// Length of a written frame
func (f *MaxDataFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaximumData)))
}

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -13,23 +11,26 @@ type MaxStreamDataFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.Version) (*MaxStreamDataFrame, error) {
sid, err := quicvarint.Read(r)
func parseMaxStreamDataFrame(b []byte, _ protocol.Version) (*MaxStreamDataFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
offset, err := quicvarint.Read(r)
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &MaxStreamDataFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}, startLen - len(b), nil
}
func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte, error) {
func (f *MaxStreamDataFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
b = append(b, maxStreamDataFrameType)
b = quicvarint.Append(b, uint64(f.StreamID))
b = quicvarint.Append(b, uint64(f.MaximumStreamData))
@@ -37,6 +38,6 @@ func (f *MaxStreamDataFrame) Append(b []byte, version protocol.Version) ([]byte,
}
// Length of a written frame
func (f *MaxStreamDataFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
func (f *MaxStreamDataFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData)))
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
@@ -14,7 +13,7 @@ type MaxStreamsFrame struct {
MaxStreamNum protocol.StreamNum
}
func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*MaxStreamsFrame, error) {
func parseMaxStreamsFrame(b []byte, typ uint64, _ protocol.Version) (*MaxStreamsFrame, int, error) {
f := &MaxStreamsFrame{}
switch typ {
case bidiMaxStreamsFrameType:
@@ -22,15 +21,15 @@ func parseMaxStreamsFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*Max
case uniMaxStreamsFrameType:
f.Type = protocol.StreamTypeUni
}
streamID, err := quicvarint.Read(r)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
f.MaxStreamNum = protocol.StreamNum(streamID)
if f.MaxStreamNum > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
}
return f, nil
return f, l, nil
}
func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -46,5 +45,5 @@ func (f *MaxStreamsFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.MaxStreamNum)))
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
@@ -18,43 +17,47 @@ type NewConnectionIDFrame struct {
StatelessResetToken protocol.StatelessResetToken
}
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*NewConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
func parseNewConnectionIDFrame(b []byte, _ protocol.Version) (*NewConnectionIDFrame, int, error) {
startLen := len(b)
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
ret, err := quicvarint.Read(r)
b = b[l:]
ret, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if ret > seq {
//nolint:stylecheck
return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
return nil, 0, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
}
connIDLen, err := r.ReadByte()
if err != nil {
return nil, err
if len(b) == 0 {
return nil, 0, io.EOF
}
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 {
return nil, errors.New("invalid zero-length connection ID")
return nil, 0, errors.New("invalid zero-length connection ID")
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return nil, err
if connIDLen > protocol.MaxConnIDLen {
return nil, 0, protocol.ErrInvalidConnectionIDLen
}
if len(b) < connIDLen {
return nil, 0, io.EOF
}
frame := &NewConnectionIDFrame{
SequenceNumber: seq,
RetirePriorTo: ret,
ConnectionID: connID,
ConnectionID: protocol.ParseConnectionID(b[:connIDLen]),
}
if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
b = b[connIDLen:]
if len(b) < len(frame.StatelessResetToken) {
return nil, 0, io.EOF
}
return frame, nil
copy(frame.StatelessResetToken[:], b)
return frame, startLen - len(b) + len(frame.StatelessResetToken), nil
}
func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -73,5 +76,5 @@ func (f *NewConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, err
// Length of a written frame
func (f *NewConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber)+quicvarint.Len(f.RetirePriorTo)+1 /* connection ID length */ +f.ConnectionID.Len()) + 16
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"io"
@@ -14,22 +13,21 @@ type NewTokenFrame struct {
Token []byte
}
func parseNewTokenFrame(r *bytes.Reader, _ protocol.Version) (*NewTokenFrame, error) {
tokenLen, err := quicvarint.Read(r)
func parseNewTokenFrame(b []byte, _ protocol.Version) (*NewTokenFrame, int, error) {
tokenLen, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
}
if uint64(r.Len()) < tokenLen {
return nil, io.EOF
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if tokenLen == 0 {
return nil, errors.New("token must not be empty")
return nil, 0, errors.New("token must not be empty")
}
if uint64(len(b)) < tokenLen {
return nil, 0, io.EOF
}
token := make([]byte, int(tokenLen))
if _, err := io.ReadFull(r, token); err != nil {
return nil, err
}
return &NewTokenFrame{Token: token}, nil
copy(token, b)
return &NewTokenFrame{Token: token}, l + int(tokenLen), nil
}
func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -41,5 +39,5 @@ func (f *NewTokenFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
// Length of a written frame
func (f *NewTokenFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(len(f.Token)))+len(f.Token))
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -12,15 +11,13 @@ type PathChallengeFrame struct {
Data [8]byte
}
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.Version) (*PathChallengeFrame, error) {
frame := &PathChallengeFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
func parsePathChallengeFrame(b []byte, _ protocol.Version) (*PathChallengeFrame, int, error) {
f := &PathChallengeFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
return frame, nil
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathChallengeFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"io"
"github.com/quic-go/quic-go/internal/protocol"
@@ -12,15 +11,13 @@ type PathResponseFrame struct {
Data [8]byte
}
func parsePathResponseFrame(r *bytes.Reader, _ protocol.Version) (*PathResponseFrame, error) {
frame := &PathResponseFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
func parsePathResponseFrame(b []byte, _ protocol.Version) (*PathResponseFrame, int, error) {
f := &PathResponseFrame{}
if len(b) < 8 {
return nil, 0, io.EOF
}
return frame, nil
copy(f.Data[:], b)
return f, 8, nil
}
func (f *PathResponseFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
@@ -15,21 +13,24 @@ type ResetStreamFrame struct {
FinalSize protocol.ByteCount
}
func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFrame, error) {
func parseResetStreamFrame(b []byte, _ protocol.Version) (*ResetStreamFrame, int, error) {
startLen := len(b)
var streamID protocol.StreamID
var byteOffset protocol.ByteCount
sid, err := quicvarint.Read(r)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
streamID = protocol.StreamID(sid)
errorCode, err := quicvarint.Read(r)
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
bo, err := quicvarint.Read(r)
b = b[l:]
bo, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
byteOffset = protocol.ByteCount(bo)
@@ -37,7 +38,7 @@ func parseResetStreamFrame(r *bytes.Reader, _ protocol.Version) (*ResetStreamFra
StreamID: streamID,
ErrorCode: qerr.StreamErrorCode(errorCode),
FinalSize: byteOffset,
}, nil
}, startLen - len(b) + l, nil
}
func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -49,6 +50,6 @@ func (f *ResetStreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error)
}
// Length of a written frame
func (f *ResetStreamFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
func (f *ResetStreamFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.ErrorCode))+quicvarint.Len(uint64(f.FinalSize)))
}

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -12,12 +10,12 @@ type RetireConnectionIDFrame struct {
SequenceNumber uint64
}
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.Version) (*RetireConnectionIDFrame, error) {
seq, err := quicvarint.Read(r)
func parseRetireConnectionIDFrame(b []byte, _ protocol.Version) (*RetireConnectionIDFrame, int, error) {
seq, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
return &RetireConnectionIDFrame{SequenceNumber: seq}, l, nil
}
func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -28,5 +26,5 @@ func (f *RetireConnectionIDFrame) Append(b []byte, _ protocol.Version) ([]byte,
// Length of a written frame
func (f *RetireConnectionIDFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber)
return 1 + protocol.ByteCount(quicvarint.Len(f.SequenceNumber))
}

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/quicvarint"
@@ -15,25 +13,28 @@ type StopSendingFrame struct {
}
// parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(r *bytes.Reader, _ protocol.Version) (*StopSendingFrame, error) {
streamID, err := quicvarint.Read(r)
func parseStopSendingFrame(b []byte, _ protocol.Version) (*StopSendingFrame, int, error) {
startLen := len(b)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
errorCode, err := quicvarint.Read(r)
b = b[l:]
errorCode, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.StreamErrorCode(errorCode),
}, nil
}, startLen - len(b), nil
}
// Length of a written frame
func (f *StopSendingFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.ErrorCode)))
}
func (f *StopSendingFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {

View File

@@ -1,8 +1,6 @@
package wire
import (
"bytes"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -13,20 +11,22 @@ type StreamDataBlockedFrame struct {
MaximumStreamData protocol.ByteCount
}
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.Version) (*StreamDataBlockedFrame, error) {
sid, err := quicvarint.Read(r)
func parseStreamDataBlockedFrame(b []byte, _ protocol.Version) (*StreamDataBlockedFrame, int, error) {
startLen := len(b)
sid, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
offset, err := quicvarint.Read(r)
b = b[l:]
offset, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
return &StreamDataBlockedFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}, startLen - len(b) + l, nil
}
func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -37,6 +37,6 @@ func (f *StreamDataBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, e
}
// Length of a written frame
func (f *StreamDataBlockedFrame) Length(version protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
func (f *StreamDataBlockedFrame) Length(protocol.Version) protocol.ByteCount {
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID))+quicvarint.Len(uint64(f.MaximumStreamData)))
}

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"errors"
"io"
@@ -20,33 +19,41 @@ type StreamFrame struct {
fromPool bool
}
func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamFrame, error) {
func parseStreamFrame(b []byte, typ uint64, _ protocol.Version) (*StreamFrame, int, error) {
startLen := len(b)
hasOffset := typ&0b100 > 0
fin := typ&0b1 > 0
hasDataLen := typ&0b10 > 0
streamID, err := quicvarint.Read(r)
streamID, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
var offset uint64
if hasOffset {
offset, err = quicvarint.Read(r)
offset, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
}
var dataLen uint64
if hasDataLen {
var err error
dataLen, err = quicvarint.Read(r)
var l int
dataLen, l, err = quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
b = b[l:]
if dataLen > uint64(len(b)) {
return nil, 0, io.EOF
}
} else {
// The rest of the packet is data
dataLen = uint64(r.Len())
dataLen = uint64(len(b))
}
var frame *StreamFrame
@@ -57,7 +64,7 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
// since those StreamFrames have a buffer length of the maximum packet size.
if dataLen > uint64(cap(frame.Data)) {
return nil, io.EOF
return nil, 0, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
@@ -68,17 +75,14 @@ func parseStreamFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamF
frame.DataLenPresent = hasDataLen
if dataLen != 0 {
if _, err := io.ReadFull(r, frame.Data); err != nil {
return nil, err
}
copy(frame.Data, b)
}
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
return nil, errors.New("stream data overflows maximum offset")
return nil, 0, errors.New("stream data overflows maximum offset")
}
return frame, nil
return frame, startLen - len(b) + int(dataLen), nil
}
// Write writes a STREAM frame
func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
if len(f.Data) == 0 && !f.Fin {
return nil, errors.New("StreamFrame: attempting to write empty frame without FIN")
@@ -108,7 +112,7 @@ func (f *StreamFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
}
// Length returns the total length of the STREAM frame
func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
func (f *StreamFrame) Length(protocol.Version) protocol.ByteCount {
length := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
length += quicvarint.Len(uint64(f.Offset))
@@ -116,7 +120,7 @@ func (f *StreamFrame) Length(version protocol.Version) protocol.ByteCount {
if f.DataLenPresent {
length += quicvarint.Len(uint64(f.DataLen()))
}
return length + f.DataLen()
return protocol.ByteCount(length) + f.DataLen()
}
// DataLen gives the length of data in bytes
@@ -126,14 +130,14 @@ func (f *StreamFrame) DataLen() protocol.ByteCount {
// MaxDataLen returns the maximum data length
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.Version) protocol.ByteCount {
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, _ protocol.Version) protocol.ByteCount {
headerLen := 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamID)))
if f.Offset != 0 {
headerLen += quicvarint.Len(uint64(f.Offset))
headerLen += protocol.ByteCount(quicvarint.Len(uint64(f.Offset)))
}
if f.DataLenPresent {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
// Pretend that the data size will be 1 byte.
// If it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterward
headerLen++
}
if headerLen > maxSize {

View File

@@ -1,7 +1,6 @@
package wire
import (
"bytes"
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
@@ -14,7 +13,7 @@ type StreamsBlockedFrame struct {
StreamLimit protocol.StreamNum
}
func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, error) {
func parseStreamsBlockedFrame(b []byte, typ uint64, _ protocol.Version) (*StreamsBlockedFrame, int, error) {
f := &StreamsBlockedFrame{}
switch typ {
case bidiStreamBlockedFrameType:
@@ -22,15 +21,15 @@ func parseStreamsBlockedFrame(r *bytes.Reader, typ uint64, _ protocol.Version) (
case uniStreamBlockedFrameType:
f.Type = protocol.StreamTypeUni
}
streamLimit, err := quicvarint.Read(r)
streamLimit, l, err := quicvarint.Parse(b)
if err != nil {
return nil, err
return nil, 0, replaceUnexpectedEOF(err)
}
f.StreamLimit = protocol.StreamNum(streamLimit)
if f.StreamLimit > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
return nil, 0, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
}
return f, nil
return f, l, nil
}
func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, error) {
@@ -46,5 +45,5 @@ func (f *StreamsBlockedFrame) Append(b []byte, _ protocol.Version) ([]byte, erro
// Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.Version) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamLimit))
return 1 + protocol.ByteCount(quicvarint.Len(uint64(f.StreamLimit)))
}

View File

@@ -1,19 +1,17 @@
package wire
import (
"bytes"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net/netip"
"sort"
"slices"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/quicvarint"
)
@@ -89,7 +87,7 @@ type TransportParameters struct {
// Unmarshal the transport parameters
func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective) error {
if err := p.unmarshal(bytes.NewReader(data), sentBy, false); err != nil {
if err := p.unmarshal(data, sentBy, false); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
@@ -98,9 +96,9 @@ func (p *TransportParameters) Unmarshal(data []byte, sentBy protocol.Perspective
return nil
}
func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspective, fromSessionTicket bool) error {
func (p *TransportParameters) unmarshal(b []byte, sentBy protocol.Perspective, fromSessionTicket bool) error {
// needed to check that every parameter is only sent at most once
var parameterIDs []transportParameterID
parameterIDs := make([]transportParameterID, 0, 32)
var (
readOriginalDestinationConnectionID bool
@@ -112,18 +110,20 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
p.MaxAckDelay = protocol.DefaultMaxAckDelay
p.MaxDatagramFrameSize = protocol.InvalidByteCount
for r.Len() > 0 {
paramIDInt, err := quicvarint.Read(r)
for len(b) > 0 {
paramIDInt, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
paramID := transportParameterID(paramIDInt)
paramLen, err := quicvarint.Read(r)
b = b[l:]
paramLen, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if uint64(r.Len()) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", r.Len(), paramLen)
b = b[l:]
if uint64(len(b)) < paramLen {
return fmt.Errorf("remaining length (%d) smaller than parameter length (%d)", len(b), paramLen)
}
parameterIDs = append(parameterIDs, paramID)
switch paramID {
@@ -141,16 +141,18 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
maxAckDelayParameterID,
maxDatagramFrameSizeParameterID,
ackDelayExponentParameterID:
if err := p.readNumericTransportParameter(r, paramID, int(paramLen)); err != nil {
if err := p.readNumericTransportParameter(b, paramID, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case preferredAddressParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a preferred_address")
}
if err := p.readPreferredAddress(r, int(paramLen)); err != nil {
if err := p.readPreferredAddress(b, int(paramLen)); err != nil {
return err
}
b = b[paramLen:]
case disableActiveMigrationParameterID:
if paramLen != 0 {
return fmt.Errorf("wrong length for disable_active_migration: %d (expected empty)", paramLen)
@@ -164,25 +166,41 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return fmt.Errorf("wrong length for stateless_reset_token: %d (expected 16)", paramLen)
}
var token protocol.StatelessResetToken
r.Read(token[:])
if len(b) < len(token) {
return io.EOF
}
copy(token[:], b)
b = b[len(token):]
p.StatelessResetToken = &token
case originalDestinationConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent an original_destination_connection_id")
}
p.OriginalDestinationConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.OriginalDestinationConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readOriginalDestinationConnectionID = true
case initialSourceConnectionIDParameterID:
p.InitialSourceConnectionID, _ = protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
p.InitialSourceConnectionID = protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
readInitialSourceConnectionID = true
case retrySourceConnectionIDParameterID:
if sentBy == protocol.PerspectiveClient {
return errors.New("client sent a retry_source_connection_id")
}
connID, _ := protocol.ReadConnectionID(r, int(paramLen))
if paramLen > protocol.MaxConnIDLen {
return protocol.ErrInvalidConnectionIDLen
}
connID := protocol.ParseConnectionID(b[:paramLen])
b = b[paramLen:]
p.RetrySourceConnectionID = &connID
default:
r.Seek(int64(paramLen), io.SeekCurrent)
b = b[paramLen:]
}
}
@@ -202,7 +220,12 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
}
// check that every transport parameter was sent at most once
sort.Slice(parameterIDs, func(i, j int) bool { return parameterIDs[i] < parameterIDs[j] })
slices.SortFunc(parameterIDs, func(a, b transportParameterID) int {
if a < b {
return -1
}
return 1
})
for i := 0; i < len(parameterIDs)-1; i++ {
if parameterIDs[i] == parameterIDs[i+1] {
return fmt.Errorf("received duplicate transport parameter %#x", parameterIDs[i])
@@ -212,60 +235,47 @@ func (p *TransportParameters) unmarshal(r *bytes.Reader, sentBy protocol.Perspec
return nil
}
func (p *TransportParameters) readPreferredAddress(r *bytes.Reader, expectedLen int) error {
remainingLen := r.Len()
func (p *TransportParameters) readPreferredAddress(b []byte, expectedLen int) error {
remainingLen := len(b)
pa := &PreferredAddress{}
if len(b) < 4+2+16+2+1 {
return io.EOF
}
var ipv4 [4]byte
if _, err := io.ReadFull(r, ipv4[:]); err != nil {
return err
}
port, err := utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port)
copy(ipv4[:], b[:4])
port4 := binary.BigEndian.Uint16(b[4:])
b = b[4+2:]
pa.IPv4 = netip.AddrPortFrom(netip.AddrFrom4(ipv4), port4)
var ipv6 [16]byte
if _, err := io.ReadFull(r, ipv6[:]); err != nil {
return err
}
port, err = utils.BigEndian.ReadUint16(r)
if err != nil {
return err
}
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port)
connIDLen, err := r.ReadByte()
if err != nil {
return err
}
copy(ipv6[:], b[:16])
port6 := binary.BigEndian.Uint16(b[16:])
pa.IPv6 = netip.AddrPortFrom(netip.AddrFrom16(ipv6), port6)
b = b[16+2:]
connIDLen := int(b[0])
b = b[1:]
if connIDLen == 0 || connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return err
if len(b) < connIDLen+len(pa.StatelessResetToken) {
return io.EOF
}
pa.ConnectionID = connID
if _, err := io.ReadFull(r, pa.StatelessResetToken[:]); err != nil {
return err
}
if bytesRead := remainingLen - r.Len(); bytesRead != expectedLen {
pa.ConnectionID = protocol.ParseConnectionID(b[:connIDLen])
b = b[connIDLen:]
copy(pa.StatelessResetToken[:], b)
b = b[len(pa.StatelessResetToken):]
if bytesRead := remainingLen - len(b); bytesRead != expectedLen {
return fmt.Errorf("expected preferred_address to be %d long, read %d bytes", expectedLen, bytesRead)
}
p.PreferredAddress = pa
return nil
}
func (p *TransportParameters) readNumericTransportParameter(
r *bytes.Reader,
paramID transportParameterID,
expectedLen int,
) error {
remainingLen := r.Len()
val, err := quicvarint.Read(r)
func (p *TransportParameters) readNumericTransportParameter(b []byte, paramID transportParameterID, expectedLen int) error {
val, l, err := quicvarint.Parse(b)
if err != nil {
return fmt.Errorf("error while reading transport parameter %d: %s", paramID, err)
}
if remainingLen-r.Len() != expectedLen {
if l != expectedLen {
return fmt.Errorf("inconsistent transport parameter length for transport parameter %#x", paramID)
}
//nolint:exhaustive // This only covers the numeric transport parameters.
@@ -292,7 +302,7 @@ func (p *TransportParameters) readNumericTransportParameter(
p.MaxIdleTimeout = max(protocol.MinRemoteIdleTimeout, time.Duration(val)*time.Millisecond)
case maxUDPPayloadSizeParameterID:
if val < 1200 {
return fmt.Errorf("invalid value for max_packet_size: %d (minimum 1200)", val)
return fmt.Errorf("invalid value for max_udp_payload_size: %d (minimum 1200)", val)
}
p.MaxUDPPayloadSize = protocol.ByteCount(val)
case ackDelayExponentParameterID:
@@ -347,8 +357,10 @@ func (p *TransportParameters) Marshal(pers protocol.Perspective) []byte {
b = p.marshalVarintParam(b, initialMaxStreamsUniParameterID, uint64(p.MaxUniStreamNum))
// idle_timeout
b = p.marshalVarintParam(b, maxIdleTimeoutParameterID, uint64(p.MaxIdleTimeout/time.Millisecond))
// max_packet_size
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(protocol.MaxPacketBufferSize))
// max_udp_payload_size
if p.MaxUDPPayloadSize > 0 {
b = p.marshalVarintParam(b, maxUDPPayloadSizeParameterID, uint64(p.MaxUDPPayloadSize))
}
// max_ack_delay
// Only send it if is different from the default value.
if p.MaxAckDelay != protocol.DefaultMaxAckDelay {
@@ -457,15 +469,15 @@ func (p *TransportParameters) MarshalForSessionTicket(b []byte) []byte {
}
// UnmarshalFromSessionTicket unmarshals transport parameters from a session ticket.
func (p *TransportParameters) UnmarshalFromSessionTicket(r *bytes.Reader) error {
version, err := quicvarint.Read(r)
func (p *TransportParameters) UnmarshalFromSessionTicket(b []byte) error {
version, l, err := quicvarint.Parse(b)
if err != nil {
return err
}
if version != transportParameterMarshalingVersion {
return fmt.Errorf("unknown transport parameter marshaling version: %d", version)
}
return p.unmarshal(r, protocol.PerspectiveServer, true)
return p.unmarshal(b[l:], protocol.PerspectiveServer, true)
}
// ValidFor0RTT checks if the transport parameters match those saved in the session ticket.

View File

@@ -24,6 +24,7 @@ type ConnectionTracer struct {
UpdatedMetrics func(rttStats *RTTStats, cwnd, bytesInFlight ByteCount, packetsInFlight int)
AcknowledgedPacket func(EncryptionLevel, PacketNumber)
LostPacket func(EncryptionLevel, PacketNumber, PacketLossReason)
UpdatedMTU func(mtu ByteCount, done bool)
UpdatedCongestionState func(CongestionState)
UpdatedPTOCount func(value uint32)
UpdatedKeyFromTLS func(EncryptionLevel, Perspective)
@@ -168,6 +169,13 @@ func NewMultiplexedConnectionTracer(tracers ...*ConnectionTracer) *ConnectionTra
}
}
},
UpdatedMTU: func(mtu ByteCount, done bool) {
for _, t := range tracers {
if t.UpdatedMTU != nil {
t.UpdatedMTU(mtu, done)
}
}
},
UpdatedCongestionState: func(state CongestionState) {
for _, t := range tracers {
if t.UpdatedCongestionState != nil {

View File

@@ -1,19 +1,19 @@
package quic
import (
"net"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
type mtuDiscoverer interface {
// Start starts the MTU discovery process.
// It's unnecessary to call ShouldSendProbe before that.
Start(maxPacketSize protocol.ByteCount)
Start()
ShouldSendProbe(now time.Time) bool
CurrentSize() protocol.ByteCount
GetPing() (ping ackhandler.Frame, datagramSize protocol.ByteCount)
@@ -25,54 +25,129 @@ const (
maxMTUDiff = 20
// send a probe packet every mtuProbeDelay RTTs
mtuProbeDelay = 5
// Once maxLostMTUProbes MTU probe packets larger than a certain size are lost,
// MTU discovery won't probe for larger MTUs than this size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
maxLostMTUProbes = 3
)
func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
// If this is not a UDP address, we don't know anything about the MTU.
// Use the minimum size of an Initial packet as the max packet size.
if udpAddr, ok := addr.(*net.UDPAddr); ok {
if utils.IsIPv4(udpAddr.IP) {
maxSize = protocol.InitialPacketSizeIPv4
} else {
maxSize = protocol.InitialPacketSizeIPv6
}
}
return maxSize
}
// The Path MTU is found by sending a larger packet every now and then.
// If the packet is acknowledged, we conclude that the path supports this larger packet size.
// If the packet is lost, this can mean one of two things:
// 1. The path doesn't support this larger packet size, or
// 2. The packet was lost due to packet loss, independent of its size.
// The algorithm used here is resilient to packet loss of (maxLostMTUProbes - 1) packets.
// For simplicty, the following example use maxLostMTUProbes = 2.
//
// Initialization:
// |------------------------------------------------------------------------------|
// min max
//
// The first MTU probe packet will have size (min+max)/2.
// Assume that this packet is acknowledged. We can now move the min marker,
// and continue the search in the resulting interval.
//
// If 1st probe packet acknowledged:
// |---------------------------------------|--------------------------------------|
// min max
//
// If 1st probe packet lost:
// |---------------------------------------|--------------------------------------|
// min lost[0] max
//
// We can't conclude that the path doesn't support this packet size, since the loss of the probe
// packet could have been unrelated to the packet size. A larger probe packet will be sent later on.
// After a loss, the next probe packet has size (min+lost[0])/2.
// Now assume this probe packet is acknowledged:
//
// 2nd probe packet acknowledged:
// |------------------|--------------------|--------------------------------------|
// min lost[0] max
//
// First of all, we conclude that the path supports at least this MTU. That's progress!
// Second, we probe a bit more aggressively with the next probe packet:
// After an acknowledgement, the next probe packet has size (min+max)/2.
// This means we'll send a packet larger than the first probe packet (which was lost).
//
// If 3rd probe packet acknowledged:
// |-------------------------------------------------|----------------------------|
// min max
//
// We can conclude that the loss of the 1st probe packet was not due to its size, and
// continue searching in a much smaller interval now.
//
// If 3rd probe packet lost:
// |------------------|--------------------|---------|----------------------------|
// min lost[0] max
//
// Since in our example numPTOProbes = 2, and we lost 2 packets smaller than max, we
// conclude that this packet size is not supported on the path, and reduce the maximum
// value of the search interval.
//
// MTU discovery concludes once the interval min and max has been narrowed down to maxMTUDiff.
type mtuFinder struct {
lastProbeTime time.Time
mtuIncreased func(protocol.ByteCount)
rttStats *utils.RTTStats
inFlight protocol.ByteCount // the size of the probe packet currently in flight. InvalidByteCount if none is in flight
current protocol.ByteCount
max protocol.ByteCount // the maximum value, as advertised by the peer (or our maximum size buffer)
min protocol.ByteCount
limit protocol.ByteCount
// on initialization, we treat the maximum size as the first "lost" packet
lost [maxLostMTUProbes]protocol.ByteCount
lastProbeWasLost bool
tracer *logging.ConnectionTracer
}
var _ mtuDiscoverer = &mtuFinder{}
func newMTUDiscoverer(rttStats *utils.RTTStats, start protocol.ByteCount, mtuIncreased func(protocol.ByteCount)) *mtuFinder {
return &mtuFinder{
func newMTUDiscoverer(
rttStats *utils.RTTStats,
start, max protocol.ByteCount,
mtuIncreased func(protocol.ByteCount),
tracer *logging.ConnectionTracer,
) *mtuFinder {
f := &mtuFinder{
inFlight: protocol.InvalidByteCount,
current: start,
min: start,
limit: max,
rttStats: rttStats,
mtuIncreased: mtuIncreased,
tracer: tracer,
}
for i := range f.lost {
if i == 0 {
f.lost[i] = max
continue
}
f.lost[i] = protocol.InvalidByteCount
}
return f
}
func (f *mtuFinder) done() bool {
return f.max-f.current <= maxMTUDiff+1
return f.max()-f.min <= maxMTUDiff+1
}
func (f *mtuFinder) Start(maxPacketSize protocol.ByteCount) {
func (f *mtuFinder) max() protocol.ByteCount {
for i, v := range f.lost {
if v == protocol.InvalidByteCount {
return f.lost[i-1]
}
}
return f.lost[len(f.lost)-1]
}
func (f *mtuFinder) Start() {
f.lastProbeTime = time.Now() // makes sure the first probe packet is not sent immediately
f.max = maxPacketSize
}
func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
if f.max == 0 || f.lastProbeTime.IsZero() {
if f.lastProbeTime.IsZero() {
return false
}
if f.inFlight != protocol.InvalidByteCount || f.done() {
@@ -82,20 +157,27 @@ func (f *mtuFinder) ShouldSendProbe(now time.Time) bool {
}
func (f *mtuFinder) GetPing() (ackhandler.Frame, protocol.ByteCount) {
size := (f.max + f.current) / 2
var size protocol.ByteCount
if f.lastProbeWasLost {
size = (f.min + f.lost[0]) / 2
} else {
size = (f.min + f.max()) / 2
}
f.lastProbeTime = time.Now()
f.inFlight = size
return ackhandler.Frame{
Frame: &wire.PingFrame{},
Handler: (*mtuFinderAckHandler)(f),
Handler: &mtuFinderAckHandler{f},
}, size
}
func (f *mtuFinder) CurrentSize() protocol.ByteCount {
return f.current
return f.min
}
type mtuFinderAckHandler mtuFinder
type mtuFinderAckHandler struct {
*mtuFinder
}
var _ ackhandler.FrameHandler = &mtuFinderAckHandler{}
@@ -105,7 +187,28 @@ func (h *mtuFinderAckHandler) OnAcked(wire.Frame) {
panic("OnAcked callback called although there's no MTU probe packet in flight")
}
h.inFlight = protocol.InvalidByteCount
h.current = size
h.min = size
h.lastProbeWasLost = false
// remove all values smaller than size from the lost array
var j int
for i, v := range h.lost {
if size < v {
j = i
break
}
}
if j > 0 {
for i := 0; i < len(h.lost); i++ {
if i+j < len(h.lost) {
h.lost[i] = h.lost[i+j]
} else {
h.lost[i] = protocol.InvalidByteCount
}
}
}
if h.tracer != nil && h.tracer.UpdatedMTU != nil {
h.tracer.UpdatedMTU(size, h.done())
}
h.mtuIncreased(size)
}
@@ -114,6 +217,13 @@ func (h *mtuFinderAckHandler) OnLost(wire.Frame) {
if size == protocol.InvalidByteCount {
panic("OnLost callback called although there's no MTU probe packet in flight")
}
h.max = size
h.lastProbeWasLost = true
h.inFlight = protocol.InvalidByteCount
for i, v := range h.lost {
if size < v {
copy(h.lost[i+1:], h.lost[i:])
h.lost[i] = size
break
}
}
}

View File

@@ -3,8 +3,6 @@ package quicvarint
import (
"fmt"
"io"
"github.com/quic-go/quic-go/internal/protocol"
)
// taken from the QUIC draft
@@ -28,16 +26,16 @@ func Read(r io.ByteReader) (uint64, error) {
return 0, err
}
// the first two bits of the first byte encode the length
len := 1 << ((firstByte & 0xc0) >> 6)
l := 1 << ((firstByte & 0xc0) >> 6)
b1 := firstByte & (0xff - 0xc0)
if len == 1 {
if l == 1 {
return uint64(b1), nil
}
b2, err := r.ReadByte()
if err != nil {
return 0, err
}
if len == 2 {
if l == 2 {
return uint64(b2) + uint64(b1)<<8, nil
}
b3, err := r.ReadByte()
@@ -48,7 +46,7 @@ func Read(r io.ByteReader) (uint64, error) {
if err != nil {
return 0, err
}
if len == 4 {
if l == 4 {
return uint64(b4) + uint64(b3)<<8 + uint64(b2)<<16 + uint64(b1)<<24, nil
}
b5, err := r.ReadByte()
@@ -70,6 +68,31 @@ func Read(r io.ByteReader) (uint64, error) {
return uint64(b8) + uint64(b7)<<8 + uint64(b6)<<16 + uint64(b5)<<24 + uint64(b4)<<32 + uint64(b3)<<40 + uint64(b2)<<48 + uint64(b1)<<56, nil
}
// Parse reads a number in the QUIC varint format.
// It returns the number of bytes consumed.
func Parse(b []byte) (uint64 /* value */, int /* bytes consumed */, error) {
if len(b) == 0 {
return 0, 0, io.EOF
}
firstByte := b[0]
// the first two bits of the first byte encode the length
l := 1 << ((firstByte & 0xc0) >> 6)
if len(b) < l {
return 0, 0, io.ErrUnexpectedEOF
}
b0 := firstByte & (0xff - 0xc0)
if l == 1 {
return uint64(b0), 1, nil
}
if l == 2 {
return uint64(b[1]) + uint64(b0)<<8, 2, nil
}
if l == 4 {
return uint64(b[3]) + uint64(b[2])<<8 + uint64(b[1])<<16 + uint64(b0)<<24, 4, nil
}
return uint64(b[7]) + uint64(b[6])<<8 + uint64(b[5])<<16 + uint64(b[4])<<24 + uint64(b[3])<<32 + uint64(b[2])<<40 + uint64(b[1])<<48 + uint64(b0)<<56, 8, nil
}
// Append appends i in the QUIC varint format.
func Append(b []byte, i uint64) []byte {
if i <= maxVarInt1 {
@@ -91,7 +114,7 @@ func Append(b []byte, i uint64) []byte {
}
// AppendWithLen append i in the QUIC varint format with the desired length.
func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte {
func AppendWithLen(b []byte, i uint64, length int) []byte {
if length != 1 && length != 2 && length != 4 && length != 8 {
panic("invalid varint length")
}
@@ -109,17 +132,17 @@ func AppendWithLen(b []byte, i uint64, length protocol.ByteCount) []byte {
} else if length == 8 {
b = append(b, 0b11000000)
}
for j := protocol.ByteCount(1); j < length-l; j++ {
for j := 1; j < length-l; j++ {
b = append(b, 0)
}
for j := protocol.ByteCount(0); j < l; j++ {
for j := 0; j < l; j++ {
b = append(b, uint8(i>>(8*(l-1-j))))
}
return b
}
// Len determines the number of bytes that will be needed to write the number i.
func Len(i uint64) protocol.ByteCount {
func Len(i uint64) int {
if i <= maxVarInt1 {
return 1
}

View File

@@ -37,10 +37,14 @@ type receiveStream struct {
readPosInFrame int
currentFrameIsLast bool // is the currentFrame the last frame on this stream
finRead bool // set once we read a frame with a Fin
// Set once we read the io.EOF or the cancellation error.
// Note that for local cancellations, this doesn't necessarily mean that we know the final offset yet.
errorRead bool
completed bool // set once we've called streamSender.onStreamCompleted
cancelledRemotely bool
cancelledLocally bool
cancelErr *StreamError
closeForShutdownErr error
cancelReadErr error
resetRemotelyErr *StreamError
readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
@@ -83,7 +87,8 @@ func (s *receiveStream) Read(p []byte) (int, error) {
defer func() { <-s.readOnce }()
s.mutex.Lock()
completed, n, err := s.readImpl(p)
n, err := s.readImpl(p)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
@@ -92,18 +97,38 @@ func (s *receiveStream) Read(p []byte) (int, error) {
return n, err
}
func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, error) {
if s.finRead {
return false, 0, io.EOF
func (s *receiveStream) isNewlyCompleted() bool {
if s.completed {
return false
}
if s.cancelReadErr != nil {
return false, 0, s.cancelReadErr
// We need to know the final offset (either via FIN or RESET_STREAM) for flow control accounting.
if s.finalOffset == protocol.MaxByteCount {
return false
}
if s.resetRemotelyErr != nil {
return false, 0, s.resetRemotelyErr
// We're done with the stream if it was cancelled locally...
if s.cancelledLocally {
s.completed = true
return true
}
// ... or if the error (either io.EOF or the reset error) was read
if s.errorRead {
s.completed = true
return true
}
return false
}
func (s *receiveStream) readImpl(p []byte) (int, error) {
if s.currentFrameIsLast && s.currentFrame == nil {
s.errorRead = true
return 0, io.EOF
}
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
if s.closeForShutdownErr != nil {
return false, 0, s.closeForShutdownErr
return 0, s.closeForShutdownErr
}
var bytesRead int
@@ -113,25 +138,23 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
s.dequeueNextFrame()
}
if s.currentFrame == nil && bytesRead > 0 {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}
for {
// Stop waiting on errors
if s.closeForShutdownErr != nil {
return false, bytesRead, s.closeForShutdownErr
return bytesRead, s.closeForShutdownErr
}
if s.cancelReadErr != nil {
return false, bytesRead, s.cancelReadErr
}
if s.resetRemotelyErr != nil {
return false, bytesRead, s.resetRemotelyErr
if s.cancelledRemotely || s.cancelledLocally {
s.errorRead = true
return 0, s.cancelErr
}
deadline := s.deadline
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
return false, bytesRead, errDeadline
return bytesRead, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
@@ -161,10 +184,10 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
}
if bytesRead > len(p) {
return false, bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
return bytesRead, fmt.Errorf("BUG: bytesRead (%d) > len(p) (%d) in stream.Read", bytesRead, len(p))
}
if s.readPosInFrame > len(s.currentFrame) {
return false, bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
return bytesRead, fmt.Errorf("BUG: readPosInFrame (%d) > frame.DataLen (%d) in stream.Read", s.readPosInFrame, len(s.currentFrame))
}
m := copy(p[bytesRead:], s.currentFrame[s.readPosInFrame:])
@@ -173,20 +196,20 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
// when a RESET_STREAM was received, the flow controller was already
// informed about the final byteOffset for this stream
if s.resetRemotelyErr == nil {
if !s.cancelledRemotely {
s.flowController.AddBytesRead(protocol.ByteCount(m))
}
if s.readPosInFrame >= len(s.currentFrame) && s.currentFrameIsLast {
s.finRead = true
s.currentFrame = nil
if s.currentFrameDone != nil {
s.currentFrameDone()
}
return true, bytesRead, io.EOF
s.errorRead = true
return bytesRead, io.EOF
}
}
return false, bytesRead, nil
return bytesRead, nil
}
func (s *receiveStream) dequeueNextFrame() {
@@ -202,7 +225,8 @@ func (s *receiveStream) dequeueNextFrame() {
func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
s.mutex.Lock()
completed := s.cancelReadImpl(errorCode)
s.cancelReadImpl(errorCode)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
@@ -211,23 +235,26 @@ func (s *receiveStream) CancelRead(errorCode StreamErrorCode) {
}
}
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) bool /* completed */ {
if s.finRead || s.cancelReadErr != nil || s.resetRemotelyErr != nil {
return false
func (s *receiveStream) cancelReadImpl(errorCode qerr.StreamErrorCode) {
if s.cancelledLocally { // duplicate call to CancelRead
return
}
s.cancelReadErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.cancelledLocally = true
if s.errorRead || s.cancelledRemotely {
return
}
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: errorCode, Remote: false}
s.signalRead()
s.sender.queueControlFrame(&wire.StopSendingFrame{
StreamID: s.streamID,
ErrorCode: errorCode,
})
// We're done with this stream if the final offset was already received.
return s.finalOffset != protocol.MaxByteCount
}
func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
s.mutex.Lock()
completed, err := s.handleStreamFrameImpl(frame)
err := s.handleStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
@@ -237,59 +264,58 @@ func (s *receiveStream) handleStreamFrame(frame *wire.StreamFrame) error {
return err
}
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) (bool /* completed */, error) {
func (s *receiveStream) handleStreamFrameImpl(frame *wire.StreamFrame) error {
maxOffset := frame.Offset + frame.DataLen()
if err := s.flowController.UpdateHighestReceived(maxOffset, frame.Fin); err != nil {
return false, err
return err
}
var newlyRcvdFinalOffset bool
if frame.Fin {
newlyRcvdFinalOffset = s.finalOffset == protocol.MaxByteCount
s.finalOffset = maxOffset
}
if s.cancelReadErr != nil {
return newlyRcvdFinalOffset, nil
if s.cancelledLocally {
return nil
}
if err := s.frameQueue.Push(frame.Data, frame.Offset, frame.PutBack); err != nil {
return false, err
return err
}
s.signalRead()
return false, nil
return nil
}
func (s *receiveStream) handleResetStreamFrame(frame *wire.ResetStreamFrame) error {
s.mutex.Lock()
completed, err := s.handleResetStreamFrameImpl(frame)
err := s.handleResetStreamFrameImpl(frame)
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.flowController.Abandon()
s.sender.onStreamCompleted(s.streamID)
}
return err
}
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) (bool /*completed */, error) {
func (s *receiveStream) handleResetStreamFrameImpl(frame *wire.ResetStreamFrame) error {
if s.closeForShutdownErr != nil {
return false, nil
return nil
}
if err := s.flowController.UpdateHighestReceived(frame.FinalSize, true); err != nil {
return false, err
return err
}
newlyRcvdFinalOffset := s.finalOffset == protocol.MaxByteCount
s.finalOffset = frame.FinalSize
// ignore duplicate RESET_STREAM frames for this stream (after checking their final offset)
if s.resetRemotelyErr != nil {
return false, nil
if s.cancelledRemotely {
return nil
}
s.resetRemotelyErr = &StreamError{
StreamID: s.streamID,
ErrorCode: frame.ErrorCode,
Remote: true,
s.flowController.Abandon()
// don't save the error if the RESET_STREAM frames was received after CancelRead was called
if s.cancelledLocally {
return nil
}
s.cancelledRemotely = true
s.cancelErr = &StreamError{StreamID: s.streamID, ErrorCode: frame.ErrorCode, Remote: true}
s.signalRead()
return newlyRcvdFinalOffset, nil
return nil
}
func (s *receiveStream) SetReadDeadline(t time.Time) error {

View File

@@ -42,7 +42,11 @@ type sendStream struct {
finishedWriting bool // set once Close() is called
finSent bool // set when a STREAM_FRAME with FIN bit has been sent
completed bool // set when this stream has been reported to the streamSender as completed
// Set when the application knows about the cancellation.
// This can happen because the application called CancelWrite,
// or because Write returned the error (for remote cancellations).
cancellationFlagged bool
completed bool // set when this stream has been reported to the streamSender as completed
dataForWriting []byte // during a Write() call, this slice is the part of p that still needs to be sent out
nextFrame *wire.StreamFrame
@@ -60,6 +64,7 @@ var (
)
func newSendStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
@@ -71,7 +76,7 @@ func newSendStream(
writeChan: make(chan struct{}, 1),
writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
}
s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
return s
}
@@ -86,23 +91,32 @@ func (s *sendStream) Write(p []byte) (int, error) {
s.writeOnce <- struct{}{}
defer func() { <-s.writeOnce }()
isNewlyCompleted, n, err := s.write(p)
if isNewlyCompleted {
s.sender.onStreamCompleted(s.streamID)
}
return n, err
}
func (s *sendStream) write(p []byte) (bool /* is newly completed */, int, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.finishedWriting {
return 0, fmt.Errorf("write on closed stream %d", s.streamID)
return false, 0, fmt.Errorf("write on closed stream %d", s.streamID)
}
if s.cancelWriteErr != nil {
return 0, s.cancelWriteErr
s.cancellationFlagged = true
return s.isNewlyCompleted(), 0, s.cancelWriteErr
}
if s.closeForShutdownErr != nil {
return 0, s.closeForShutdownErr
return false, 0, s.closeForShutdownErr
}
if !s.deadline.IsZero() && !time.Now().Before(s.deadline) {
return 0, errDeadline
return false, 0, errDeadline
}
if len(p) == 0 {
return 0, nil
return false, 0, nil
}
s.dataForWriting = p
@@ -143,7 +157,7 @@ func (s *sendStream) Write(p []byte) (int, error) {
if !deadline.IsZero() {
if !time.Now().Before(deadline) {
s.dataForWriting = nil
return bytesWritten, errDeadline
return false, bytesWritten, errDeadline
}
if deadlineTimer == nil {
deadlineTimer = utils.NewTimer()
@@ -178,14 +192,15 @@ func (s *sendStream) Write(p []byte) (int, error) {
}
if bytesWritten == len(p) {
return bytesWritten, nil
return false, bytesWritten, nil
}
if s.closeForShutdownErr != nil {
return bytesWritten, s.closeForShutdownErr
return false, bytesWritten, s.closeForShutdownErr
} else if s.cancelWriteErr != nil {
return bytesWritten, s.cancelWriteErr
s.cancellationFlagged = true
return s.isNewlyCompleted(), bytesWritten, s.cancelWriteErr
}
return bytesWritten, nil
return false, bytesWritten, nil
}
func (s *sendStream) canBufferStreamFrame() bool {
@@ -348,8 +363,24 @@ func (s *sendStream) getDataForWriting(f *wire.StreamFrame, maxBytes protocol.By
}
func (s *sendStream) isNewlyCompleted() bool {
completed := (s.finSent || s.cancelWriteErr != nil) && s.numOutstandingFrames == 0 && len(s.retransmissionQueue) == 0
if completed && !s.completed {
if s.completed {
return false
}
// We need to keep the stream around until all frames have been sent and acknowledged.
if s.numOutstandingFrames > 0 || len(s.retransmissionQueue) > 0 {
return false
}
// The stream is completed if we sent the FIN.
if s.finSent {
s.completed = true
return true
}
// The stream is also completed if:
// 1. the application called CancelWrite, or
// 2. we received a STOP_SENDING, and
// * the application consumed the error via Write, or
// * the application called CLsoe
if s.cancelWriteErr != nil && (s.cancellationFlagged || s.finishedWriting) {
s.completed = true
return true
}
@@ -362,15 +393,23 @@ func (s *sendStream) Close() error {
s.mutex.Unlock()
return nil
}
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.ctxCancel(nil)
s.finishedWriting = true
cancelWriteErr := s.cancelWriteErr
if cancelWriteErr != nil {
s.cancellationFlagged = true
}
completed := s.isNewlyCompleted()
s.mutex.Unlock()
if completed {
s.sender.onStreamCompleted(s.streamID)
}
if cancelWriteErr != nil {
return fmt.Errorf("close called for canceled stream %d", s.streamID)
}
s.sender.onHasStreamData(s.streamID) // need to send the FIN, must be called without holding the mutex
s.ctxCancel(nil)
return nil
}
@@ -378,9 +417,11 @@ func (s *sendStream) CancelWrite(errorCode StreamErrorCode) {
s.cancelWriteImpl(errorCode, false)
}
// must be called after locking the mutex
func (s *sendStream) cancelWriteImpl(errorCode qerr.StreamErrorCode, remote bool) {
s.mutex.Lock()
if !remote {
s.cancellationFlagged = true
}
if s.cancelWriteErr != nil {
s.mutex.Unlock()
return
@@ -437,7 +478,6 @@ func (s *sendStream) SetWriteDeadline(t time.Time) error {
// The peer will NOT be informed about this: the stream is closed without sending a FIN or RST.
func (s *sendStream) closeForShutdown(err error) {
s.mutex.Lock()
s.ctxCancel(err)
s.closeForShutdownErr = err
s.mutex.Unlock()
s.signalWrite()

View File

@@ -76,8 +76,12 @@ type baseServer struct {
nextZeroRTTCleanup time.Time
zeroRTTQueues map[protocol.ConnectionID]*zeroRTTQueue // only initialized if acceptEarlyConns == true
connContext func(context.Context) context.Context
// set as a member, so they can be set in the tests
newConn func(
context.Context,
context.CancelCauseFunc,
sendConn,
connRunner,
protocol.ConnectionID, /* original dest connection ID */
@@ -92,7 +96,6 @@ type baseServer struct {
*handshake.TokenGenerator,
bool, /* client address validated by an address validation token */
*logging.ConnectionTracer,
uint64,
utils.Logger,
protocol.Version,
) quicConn
@@ -231,6 +234,7 @@ func newServer(
conn rawConn,
connHandler packetHandlerManager,
connIDGenerator ConnectionIDGenerator,
connContext func(context.Context) context.Context,
tlsConf *tls.Config,
config *Config,
tracer *logging.Tracer,
@@ -243,6 +247,7 @@ func newServer(
) *baseServer {
s := &baseServer{
conn: conn,
connContext: connContext,
tlsConf: tlsConf,
config: config,
tokenGenerator: handshake.NewTokenGenerator(tokenGeneratorKey),
@@ -631,7 +636,26 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
var conn quicConn
tracingID := nextConnTracingID()
var cancel context.CancelCauseFunc
ctx, cancel1 := context.WithCancelCause(context.Background())
if s.connContext != nil {
ctx = s.connContext(ctx)
if ctx == nil {
panic("quic: ConnContext returned nil")
}
// There's no guarantee that the application returns a context
// that's derived from the context we passed into ConnContext.
// We need to make sure that both contexts are cancelled.
var cancel2 context.CancelCauseFunc
ctx, cancel2 = context.WithCancelCause(ctx)
cancel = func(cause error) {
cancel1(cause)
cancel2(cause)
}
} else {
cancel = cancel1
}
ctx = context.WithValue(ctx, ConnectionTracingKey, nextConnTracingID())
var tracer *logging.ConnectionTracer
if config.Tracer != nil {
// Use the same connection ID that is passed to the client's GetLogWriter callback.
@@ -639,7 +663,7 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = config.Tracer(context.WithValue(context.Background(), ConnectionTracingKey, tracingID), protocol.PerspectiveServer, connID)
tracer = config.Tracer(ctx, protocol.PerspectiveServer, connID)
}
connID, err := s.connIDGenerator.GenerateConnectionID()
if err != nil {
@@ -647,6 +671,8 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
}
s.logger.Debugf("Changing connection ID to %s.", connID)
conn = s.newConn(
ctx,
cancel,
newSendConn(s.conn, p.remoteAddr, p.info, s.logger),
s.connHandler,
origDestConnID,
@@ -661,7 +687,6 @@ func (s *baseServer) handleInitialImpl(p receivedPacket, hdr *wire.Header) error
s.tokenGenerator,
clientAddrVerified,
tracer,
tracingID,
s.logger,
hdr.Version,
)

View File

@@ -1,6 +1,7 @@
package quic
import (
"context"
"net"
"os"
"sync"
@@ -85,7 +86,9 @@ type stream struct {
var _ Stream = &stream{}
// newStream creates a new Stream
func newStream(streamID protocol.StreamID,
func newStream(
ctx context.Context,
streamID protocol.StreamID,
sender streamSender,
flowController flowcontrol.StreamFlowController,
) *stream {
@@ -99,7 +102,7 @@ func newStream(streamID protocol.StreamID,
s.completedMutex.Unlock()
},
}
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController)
s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
senderForReceiveStream := &uniStreamSender{
streamSender: sender,
onStreamCompletedImpl: func() {

View File

@@ -45,6 +45,7 @@ func (streamOpenErr) Timeout() bool { return false }
var errTooManyOpenStreams = errors.New("too many open streams")
type streamsMap struct {
ctx context.Context // not used for cancellations, but carries the values associated with the connection
perspective protocol.Perspective
maxIncomingBidiStreams uint64
@@ -64,6 +65,7 @@ type streamsMap struct {
var _ streamManager = &streamsMap{}
func newStreamsMap(
ctx context.Context,
sender streamSender,
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
maxIncomingBidiStreams uint64,
@@ -71,6 +73,7 @@ func newStreamsMap(
perspective protocol.Perspective,
) streamManager {
m := &streamsMap{
ctx: ctx,
perspective: perspective,
newFlowController: newFlowController,
maxIncomingBidiStreams: maxIncomingBidiStreams,
@@ -86,7 +89,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
return newStream(id, m.sender, m.newFlowController(id))
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
)
@@ -94,7 +97,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeBidi,
func(num protocol.StreamNum) streamI {
id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
return newStream(id, m.sender, m.newFlowController(id))
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.maxIncomingBidiStreams,
m.sender.queueControlFrame,
@@ -103,7 +106,7 @@ func (m *streamsMap) initMaps() {
protocol.StreamTypeUni,
func(num protocol.StreamNum) sendStreamI {
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
return newSendStream(id, m.sender, m.newFlowController(id))
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
},
m.sender.queueControlFrame,
)

View File

@@ -33,4 +33,6 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) {
return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true
}
func isGSOSupported(syscall.RawConn) bool { return false }
func isGSOEnabled(syscall.RawConn) bool { return false }
func isECNEnabled() bool { return !isECNDisabledUsingEnv() }

View File

@@ -28,4 +28,6 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, _ uint32, ok bool) {
return netip.AddrFrom4(*(*[4]byte)(body)), 0, true
}
func isGSOSupported(syscall.RawConn) bool { return false }
func isGSOEnabled(syscall.RawConn) bool { return false }
func isECNEnabled() bool { return !isECNDisabledUsingEnv() }

View File

@@ -23,6 +23,12 @@ const ecnIPv4DataLen = 1
const batchSize = 8 // needs to smaller than MaxUint8 (otherwise the type of oobConn.readPos has to be changed)
var kernelVersionMajor int
func init() {
kernelVersionMajor, _ = kernelVersion()
}
func forceSetReceiveBuffer(c syscall.RawConn, bytes int) error {
var serr error
if err := c.Control(func(fd uintptr) {
@@ -55,9 +61,12 @@ func parseIPv4PktInfo(body []byte) (ip netip.Addr, ifIndex uint32, ok bool) {
return netip.AddrFrom4(*(*[4]byte)(body[8:12])), binary.LittleEndian.Uint32(body), true
}
// isGSOSupported tests if the kernel supports GSO.
// isGSOEnabled tests if the kernel supports GSO.
// Sending with GSO might still fail later on, if the interface doesn't support it (see isGSOError).
func isGSOSupported(conn syscall.RawConn) bool {
func isGSOEnabled(conn syscall.RawConn) bool {
if kernelVersionMajor < 5 {
return false
}
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_GSO"))
if err == nil && disabled {
return false
@@ -108,3 +117,40 @@ func isPermissionError(err error) bool {
}
return false
}
func isECNEnabled() bool {
return kernelVersionMajor >= 5 && !isECNDisabledUsingEnv()
}
// kernelVersion returns major and minor kernel version numbers, parsed from
// the syscall.Uname's Release field, or 0, 0 if the version can't be obtained
// or parsed.
//
// copied from the standard library's internal/syscall/unix/kernel_version_linux.go
func kernelVersion() (major, minor int) {
var uname syscall.Utsname
if err := syscall.Uname(&uname); err != nil {
return
}
var (
values [2]int
value, vi int
)
for _, c := range uname.Release {
if '0' <= c && c <= '9' {
value = (value * 10) + int(c-'0')
} else {
// Note that we're assuming N.N.N here.
// If we see anything else, we are likely to mis-parse it.
values[vi] = value
vi++
if vi >= len(values) {
break
}
value = 0
}
}
return values[0], values[1]
}

View File

@@ -59,7 +59,7 @@ func inspectWriteBuffer(c syscall.RawConn) (int, error) {
return size, serr
}
func isECNDisabled() bool {
func isECNDisabledUsingEnv() bool {
disabled, err := strconv.ParseBool(os.Getenv("QUIC_GO_DISABLE_ECN"))
return err == nil && disabled
}
@@ -147,8 +147,8 @@ func newConn(c OOBCapablePacketConn, supportsDF bool) (*oobConn, error) {
readPos: batchSize,
cap: connCapabilities{
DF: supportsDF,
GSO: isGSOSupported(rawConn),
ECN: !isECNDisabled(),
GSO: isGSOEnabled(rawConn),
ECN: isECNEnabled(),
},
}
for i := 0; i < batchSize; i++ {
@@ -247,7 +247,7 @@ func (c *oobConn) WritePacket(b []byte, addr net.Addr, packetInfoOOB []byte, gso
}
if ecn != protocol.ECNUnsupported {
if !c.capabilities().ECN {
panic("tried to send a ECN-marked packet although ECN is disabled")
panic("tried to send an ECN-marked packet although ECN is disabled")
}
if remoteUDPAddr, ok := addr.(*net.UDPAddr); ok {
if remoteUDPAddr.IP.To4() != nil {

View File

@@ -89,6 +89,17 @@ type Transport struct {
// implementation of this callback (negating its return value).
VerifySourceAddress func(net.Addr) bool
// ConnContext is called when the server accepts a new connection.
// The context is closed when the connection is closed, or when the handshake fails for any reason.
// The context returned from the callback is used to derive every other context used during the
// lifetime of the connection:
// * the context passed to crypto/tls (and used on the tls.ClientHelloInfo)
// * the context used in Config.Tracer
// * the context returned from Connection.Context
// * the context returned from SendStream.Context
// It is not used for dialed connections.
ConnContext func(context.Context) context.Context
// A Tracer traces events that don't belong to a single QUIC connection.
// Tracer.Close is called when the transport is closed.
Tracer *logging.Tracer
@@ -168,6 +179,7 @@ func (t *Transport) createServer(tlsConf *tls.Config, conf *Config, allow0RTT bo
t.conn,
t.handlerMap,
t.connIDGenerator,
t.ConnContext,
tlsConf,
conf,
t.Tracer,