TUN-4597: Add a QUIC server skeleton

- Added a QUIC server to accept streams
- Unit test for this server also tests ALPN
- Temporary echo capability for HTTP ConnectionType
This commit is contained in:
Sudarsan Reddy
2021-08-03 10:04:02 +01:00
parent fd4000184c
commit ed024d0741
768 changed files with 84848 additions and 15639 deletions

View File

@@ -0,0 +1,20 @@
package ackhandler
import "github.com/lucas-clemente/quic-go/internal/wire"
// IsFrameAckEliciting returns true if the frame is ack-eliciting.
func IsFrameAckEliciting(f wire.Frame) bool {
_, isAck := f.(*wire.AckFrame)
_, isConnectionClose := f.(*wire.ConnectionCloseFrame)
return !isAck && !isConnectionClose
}
// HasAckElicitingFrames returns true if at least one frame is ack-eliciting.
func HasAckElicitingFrames(fs []Frame) bool {
for _, f := range fs {
if IsFrameAckEliciting(f.Frame) {
return true
}
}
return false
}

View File

@@ -0,0 +1,21 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
)
// NewAckHandler creates a new SentPacketHandler and a new ReceivedPacketHandler
func NewAckHandler(
initialPacketNumber protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (SentPacketHandler, ReceivedPacketHandler) {
sph := newSentPacketHandler(initialPacketNumber, initialMaxDatagramSize, rttStats, pers, tracer, logger)
return sph, newReceivedPacketHandler(sph, rttStats, logger, version)
}

View File

@@ -0,0 +1,9 @@
package ackhandler
import "github.com/lucas-clemente/quic-go/internal/wire"
type Frame struct {
wire.Frame // nil if the frame has already been acknowledged in another packet
OnLost func(wire.Frame)
OnAcked func(wire.Frame)
}

View File

@@ -0,0 +1,3 @@
package ackhandler
//go:generate genny -pkg ackhandler -in ../utils/linkedlist/linkedlist.go -out packet_linkedlist.go gen Item=Packet

View File

@@ -0,0 +1,68 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// A Packet is a packet
type Packet struct {
PacketNumber protocol.PacketNumber
Frames []Frame
LargestAcked protocol.PacketNumber // InvalidPacketNumber if the packet doesn't contain an ACK
Length protocol.ByteCount
EncryptionLevel protocol.EncryptionLevel
SendTime time.Time
IsPathMTUProbePacket bool // We don't report the loss of Path MTU probe packets to the congestion controller.
includedInBytesInFlight bool
declaredLost bool
skippedPacket bool
}
// SentPacketHandler handles ACKs received for outgoing packets
type SentPacketHandler interface {
// SentPacket may modify the packet
SentPacket(packet *Packet)
ReceivedAck(ackFrame *wire.AckFrame, encLevel protocol.EncryptionLevel, recvTime time.Time) (bool /* 1-RTT packet acked */, error)
ReceivedBytes(protocol.ByteCount)
DropPackets(protocol.EncryptionLevel)
ResetForRetry() error
SetHandshakeConfirmed()
// The SendMode determines if and what kind of packets can be sent.
SendMode() SendMode
// TimeUntilSend is the time when the next packet should be sent.
// It is used for pacing packets.
TimeUntilSend() time.Time
// HasPacingBudget says if the pacer allows sending of a (full size) packet at this moment.
HasPacingBudget() bool
SetMaxDatagramSize(count protocol.ByteCount)
// only to be called once the handshake is complete
QueueProbePacket(protocol.EncryptionLevel) bool /* was a packet queued */
PeekPacketNumber(protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen)
PopPacketNumber(protocol.EncryptionLevel) protocol.PacketNumber
GetLossDetectionTimeout() time.Time
OnLossDetectionTimeout() error
}
type sentPacketTracker interface {
GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
ReceivedPacket(protocol.EncryptionLevel)
}
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
IsPotentiallyDuplicate(protocol.PacketNumber, protocol.EncryptionLevel) bool
ReceivedPacket(pn protocol.PacketNumber, ecn protocol.ECN, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
DropPackets(protocol.EncryptionLevel)
GetAlarmTimeout() time.Time
GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame
}

View File

@@ -0,0 +1,3 @@
package ackhandler
//go:generate sh -c "../../mockgen_private.sh ackhandler mock_sent_packet_tracker_test.go github.com/lucas-clemente/quic-go/internal/ackhandler sentPacketTracker"

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package ackhandler
// Linked list implementation from the Go standard library.
// PacketElement is an element of a linked list.
type PacketElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *PacketElement
// The list to which this element belongs.
list *PacketList
// The value stored with this element.
Value Packet
}
// Next returns the next list element or nil.
func (e *PacketElement) Next() *PacketElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *PacketElement) Prev() *PacketElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// PacketList is a linked list of Packets.
type PacketList struct {
root PacketElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *PacketList) Init() *PacketList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewPacketList returns an initialized list.
func NewPacketList() *PacketList { return new(PacketList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *PacketList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketList) Front() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketList) Back() *PacketElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *PacketList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *PacketList) insert(e, at *PacketElement) *PacketElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketList) insertValue(v Packet, at *PacketElement) *PacketElement {
return l.insert(&PacketElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *PacketList) remove(e *PacketElement) *PacketElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketList) Remove(e *PacketElement) Packet {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *PacketList) PushFront(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *PacketList) PushBack(v Packet) *PacketElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertBefore(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketList) InsertAfter(v Packet, mark *PacketElement) *PacketElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToFront(e *PacketElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketList) MoveToBack(e *PacketElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveBefore(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketList) MoveAfter(e, mark *PacketElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushBackList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketList) PushFrontList(other *PacketList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,76 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type packetNumberGenerator interface {
Peek() protocol.PacketNumber
Pop() protocol.PacketNumber
}
type sequentialPacketNumberGenerator struct {
next protocol.PacketNumber
}
var _ packetNumberGenerator = &sequentialPacketNumberGenerator{}
func newSequentialPacketNumberGenerator(initial protocol.PacketNumber) packetNumberGenerator {
return &sequentialPacketNumberGenerator{next: initial}
}
func (p *sequentialPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next
}
func (p *sequentialPacketNumberGenerator) Pop() protocol.PacketNumber {
next := p.next
p.next++
return next
}
// The skippingPacketNumberGenerator generates the packet number for the next packet
// it randomly skips a packet number every averagePeriod packets (on average).
// It is guaranteed to never skip two consecutive packet numbers.
type skippingPacketNumberGenerator struct {
period protocol.PacketNumber
maxPeriod protocol.PacketNumber
next protocol.PacketNumber
nextToSkip protocol.PacketNumber
rng utils.Rand
}
var _ packetNumberGenerator = &skippingPacketNumberGenerator{}
func newSkippingPacketNumberGenerator(initial, initialPeriod, maxPeriod protocol.PacketNumber) packetNumberGenerator {
g := &skippingPacketNumberGenerator{
next: initial,
period: initialPeriod,
maxPeriod: maxPeriod,
}
g.generateNewSkip()
return g
}
func (p *skippingPacketNumberGenerator) Peek() protocol.PacketNumber {
return p.next
}
func (p *skippingPacketNumberGenerator) Pop() protocol.PacketNumber {
next := p.next
p.next++ // generate a new packet number for the next packet
if p.next == p.nextToSkip {
p.next++
p.generateNewSkip()
}
return next
}
func (p *skippingPacketNumberGenerator) generateNewSkip() {
// make sure that there are never two consecutive packet numbers that are skipped
p.nextToSkip = p.next + 2 + protocol.PacketNumber(p.rng.Int31n(int32(2*p.period)))
p.period = utils.MinPacketNumber(2*p.period, p.maxPeriod)
}

View File

@@ -0,0 +1,136 @@
package ackhandler
import (
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
type receivedPacketHandler struct {
sentPackets sentPacketTracker
initialPackets *receivedPacketTracker
handshakePackets *receivedPacketTracker
appDataPackets *receivedPacketTracker
lowest1RTTPacket protocol.PacketNumber
}
var _ ReceivedPacketHandler = &receivedPacketHandler{}
func newReceivedPacketHandler(
sentPackets sentPacketTracker,
rttStats *utils.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) ReceivedPacketHandler {
return &receivedPacketHandler{
sentPackets: sentPackets,
initialPackets: newReceivedPacketTracker(rttStats, logger, version),
handshakePackets: newReceivedPacketTracker(rttStats, logger, version),
appDataPackets: newReceivedPacketTracker(rttStats, logger, version),
lowest1RTTPacket: protocol.InvalidPacketNumber,
}
}
func (h *receivedPacketHandler) ReceivedPacket(
pn protocol.PacketNumber,
ecn protocol.ECN,
encLevel protocol.EncryptionLevel,
rcvTime time.Time,
shouldInstigateAck bool,
) error {
h.sentPackets.ReceivedPacket(encLevel)
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.EncryptionHandshake:
h.handshakePackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.Encryption0RTT:
if h.lowest1RTTPacket != protocol.InvalidPacketNumber && pn > h.lowest1RTTPacket {
return fmt.Errorf("received packet number %d on a 0-RTT packet after receiving %d on a 1-RTT packet", pn, h.lowest1RTTPacket)
}
h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
case protocol.Encryption1RTT:
if h.lowest1RTTPacket == protocol.InvalidPacketNumber || pn < h.lowest1RTTPacket {
h.lowest1RTTPacket = pn
}
h.appDataPackets.IgnoreBelow(h.sentPackets.GetLowestPacketNotConfirmedAcked())
h.appDataPackets.ReceivedPacket(pn, ecn, rcvTime, shouldInstigateAck)
default:
panic(fmt.Sprintf("received packet with unknown encryption level: %s", encLevel))
}
return nil
}
func (h *receivedPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
//nolint:exhaustive // 1-RTT packet number space is never dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
case protocol.Encryption0RTT:
// Nothing to do here.
// If we are rejecting 0-RTT, no 0-RTT packets will have been decrypted.
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
}
func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
var initialAlarm, handshakeAlarm time.Time
if h.initialPackets != nil {
initialAlarm = h.initialPackets.GetAlarmTimeout()
}
if h.handshakePackets != nil {
handshakeAlarm = h.handshakePackets.GetAlarmTimeout()
}
oneRTTAlarm := h.appDataPackets.GetAlarmTimeout()
return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
}
func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel, onlyIfQueued bool) *wire.AckFrame {
var ack *wire.AckFrame
//nolint:exhaustive // 0-RTT packets can't contain ACK frames.
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
ack = h.initialPackets.GetAckFrame(onlyIfQueued)
}
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
ack = h.handshakePackets.GetAckFrame(onlyIfQueued)
}
case protocol.Encryption1RTT:
// 0-RTT packets can't contain ACK frames
return h.appDataPackets.GetAckFrame(onlyIfQueued)
default:
return nil
}
// For Initial and Handshake ACKs, the delay time is ignored by the receiver.
// Set it to 0 in order to save bytes.
if ack != nil {
ack.DelayTime = 0
}
return ack
}
func (h *receivedPacketHandler) IsPotentiallyDuplicate(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) bool {
switch encLevel {
case protocol.EncryptionInitial:
if h.initialPackets != nil {
return h.initialPackets.IsPotentiallyDuplicate(pn)
}
case protocol.EncryptionHandshake:
if h.handshakePackets != nil {
return h.handshakePackets.IsPotentiallyDuplicate(pn)
}
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets.IsPotentiallyDuplicate(pn)
}
panic("unexpected encryption level")
}

View File

@@ -0,0 +1,142 @@
package ackhandler
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// The receivedPacketHistory stores if a packet number has already been received.
// It generates ACK ranges which can be used to assemble an ACK frame.
// It does not store packet contents.
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
deletedBelow protocol.PacketNumber
}
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(),
}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
// ignore delayed packets, if we already deleted the range
if p < h.deletedBelow {
return false
}
isNew := h.addToRanges(p)
h.maybeDeleteOldRanges()
return isNew
}
func (h *receivedPacketHistory) addToRanges(p protocol.PacketNumber) bool /* is a new packet (and not a duplicate / delayed packet) */ {
if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
return false
}
if el.Value.End == p-1 { // extend a range at the end
el.Value.End = p
return true
}
if el.Value.Start == p+1 { // extend a range at the beginning
el.Value.Start = p
prev := el.Prev()
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
}
return true
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
return true
}
}
// create a new range at the beginning
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
return true
}
// Delete old ranges, if we're tracking more than 500 of them.
// This is a DoS defense against a peer that sends us too many gaps.
func (h *receivedPacketHistory) maybeDeleteOldRanges() {
for h.ranges.Len() > protocol.MaxNumAckRanges {
h.ranges.Remove(h.ranges.Front())
}
}
// DeleteBelow deletes all entries below (but not including) p
func (h *receivedPacketHistory) DeleteBelow(p protocol.PacketNumber) {
if p < h.deletedBelow {
return
}
h.deletedBelow = p
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if el.Value.End < p { // delete a whole range
h.ranges.Remove(el)
} else if p > el.Value.Start && p <= el.Value.End {
el.Value.Start = p
return
} else { // no ranges affected. Nothing to do
return
}
}
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []wire.AckRange {
if h.ranges.Len() == 0 {
return nil
}
ackRanges := make([]wire.AckRange, h.ranges.Len())
i := 0
for el := h.ranges.Back(); el != nil; el = el.Prev() {
ackRanges[i] = wire.AckRange{Smallest: el.Value.Start, Largest: el.Value.End}
i++
}
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() wire.AckRange {
ackRange := wire.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.Smallest = r.Start
ackRange.Largest = r.End
}
return ackRange
}
func (h *receivedPacketHistory) IsPotentiallyDuplicate(p protocol.PacketNumber) bool {
if p < h.deletedBelow {
return true
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
if p > el.Value.End {
return false
}
if p <= el.Value.End && p >= el.Value.Start {
return true
}
}
return false
}

View File

@@ -0,0 +1,196 @@
package ackhandler
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)
// number of ack-eliciting packets received before sending an ack.
const packetsBeforeAck = 2
type receivedPacketTracker struct {
largestObserved protocol.PacketNumber
ignoreBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
ect0, ect1, ecnce uint64
packetHistory *receivedPacketHistory
maxAckDelay time.Duration
rttStats *utils.RTTStats
hasNewAck bool // true as soon as we received an ack-eliciting new packet
ackQueued bool // true once we received more than 2 (or later in the connection 10) ack-eliciting packets
ackElicitingPacketsReceivedSinceLastAck int
ackAlarm time.Time
lastAck *wire.AckFrame
logger utils.Logger
version protocol.VersionNumber
}
func newReceivedPacketTracker(
rttStats *utils.RTTStats,
logger utils.Logger,
version protocol.VersionNumber,
) *receivedPacketTracker {
return &receivedPacketTracker{
packetHistory: newReceivedPacketHistory(),
maxAckDelay: protocol.MaxAckDelay,
rttStats: rttStats,
logger: logger,
version: version,
}
}
func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, ecn protocol.ECN, rcvTime time.Time, shouldInstigateAck bool) {
if packetNumber < h.ignoreBelow {
return
}
isMissing := h.isMissing(packetNumber)
if packetNumber >= h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = rcvTime
}
if isNew := h.packetHistory.ReceivedPacket(packetNumber); isNew && shouldInstigateAck {
h.hasNewAck = true
}
if shouldInstigateAck {
h.maybeQueueAck(packetNumber, rcvTime, isMissing)
}
switch ecn {
case protocol.ECNNon:
case protocol.ECT0:
h.ect0++
case protocol.ECT1:
h.ect1++
case protocol.ECNCE:
h.ecnce++
}
}
// IgnoreBelow sets a lower limit for acknowledging packets.
// Packets with packet numbers smaller than p will not be acked.
func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) {
if p <= h.ignoreBelow {
return
}
h.ignoreBelow = p
h.packetHistory.DeleteBelow(p)
if h.logger.Debug() {
h.logger.Debugf("\tIgnoring all packets below %d.", p)
}
}
// isMissing says if a packet was reported missing in the last ACK.
func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
if h.lastAck == nil || p < h.ignoreBelow {
return false
}
return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
}
func (h *receivedPacketTracker) hasNewMissingPackets() bool {
if h.lastAck == nil {
return false
}
highestRange := h.packetHistory.GetHighestAckRange()
return highestRange.Smallest > h.lastAck.LargestAcked()+1 && highestRange.Len() == 1
}
// maybeQueueAck queues an ACK, if necessary.
func (h *receivedPacketTracker) maybeQueueAck(pn protocol.PacketNumber, rcvTime time.Time, wasMissing bool) {
// always acknowledge the first packet
if h.lastAck == nil {
if !h.ackQueued {
h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
}
h.ackQueued = true
return
}
if h.ackQueued {
return
}
h.ackElicitingPacketsReceivedSinceLastAck++
// Send an ACK if this packet was reported missing in an ACK sent before.
// Ack decimation with reordering relies on the timer to send an ACK, but if
// missing packets we reported in the previous ack, send an ACK immediately.
if wasMissing {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d was missing before.", pn)
}
h.ackQueued = true
}
// send an ACK every 2 ack-eliciting packets
if h.ackElicitingPacketsReceivedSinceLastAck >= packetsBeforeAck {
if h.logger.Debug() {
h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.ackElicitingPacketsReceivedSinceLastAck, packetsBeforeAck)
}
h.ackQueued = true
} else if h.ackAlarm.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", h.maxAckDelay)
}
h.ackAlarm = rcvTime.Add(h.maxAckDelay)
}
// Queue an ACK if there are new missing packets to report.
if h.hasNewMissingPackets() {
h.logger.Debugf("\tQueuing ACK because there's a new missing packet to report.")
h.ackQueued = true
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
}
}
func (h *receivedPacketTracker) GetAckFrame(onlyIfQueued bool) *wire.AckFrame {
if !h.hasNewAck {
return nil
}
now := time.Now()
if onlyIfQueued {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
return nil
}
if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
h.logger.Debugf("Sending ACK because the ACK timer expired.")
}
}
ack := &wire.AckFrame{
AckRanges: h.packetHistory.GetAckRanges(),
// Make sure that the DelayTime is always positive.
// This is not guaranteed on systems that don't have a monotonic clock.
DelayTime: utils.MaxDuration(0, now.Sub(h.largestObservedReceivedTime)),
ECT0: h.ect0,
ECT1: h.ect1,
ECNCE: h.ecnce,
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.hasNewAck = false
h.ackElicitingPacketsReceivedSinceLastAck = 0
return ack
}
func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }
func (h *receivedPacketTracker) IsPotentiallyDuplicate(pn protocol.PacketNumber) bool {
return h.packetHistory.IsPotentiallyDuplicate(pn)
}

View File

@@ -0,0 +1,40 @@
package ackhandler
import "fmt"
// The SendMode says what kind of packets can be sent.
type SendMode uint8
const (
// SendNone means that no packets should be sent
SendNone SendMode = iota
// SendAck means an ACK-only packet should be sent
SendAck
// SendPTOInitial means that an Initial probe packet should be sent
SendPTOInitial
// SendPTOHandshake means that a Handshake probe packet should be sent
SendPTOHandshake
// SendPTOAppData means that an Application data probe packet should be sent
SendPTOAppData
// SendAny means that any packet should be sent
SendAny
)
func (s SendMode) String() string {
switch s {
case SendNone:
return "none"
case SendAck:
return "ack"
case SendPTOInitial:
return "pto (Initial)"
case SendPTOHandshake:
return "pto (Handshake)"
case SendPTOAppData:
return "pto (Application Data)"
case SendAny:
return "any"
default:
return fmt.Sprintf("invalid send mode: %d", s)
}
}

View File

@@ -0,0 +1,832 @@
package ackhandler
import (
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/congestion"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
)
const (
// Maximum reordering in time space before time based loss detection considers a packet lost.
// Specified as an RTT multiplier.
timeThreshold = 9.0 / 8
// Maximum reordering in packets before packet threshold loss detection considers a packet lost.
packetThreshold = 3
// Before validating the client's address, the server won't send more than 3x bytes than it received.
amplificationFactor = 3
// We use Retry packets to derive an RTT estimate. Make sure we don't set the RTT to a super low value yet.
minRTTAfterRetry = 5 * time.Millisecond
)
type packetNumberSpace struct {
history *sentPacketHistory
pns packetNumberGenerator
lossTime time.Time
lastAckElicitingPacketTime time.Time
largestAcked protocol.PacketNumber
largestSent protocol.PacketNumber
}
func newPacketNumberSpace(initialPN protocol.PacketNumber, skipPNs bool, rttStats *utils.RTTStats) *packetNumberSpace {
var pns packetNumberGenerator
if skipPNs {
pns = newSkippingPacketNumberGenerator(initialPN, protocol.SkipPacketInitialPeriod, protocol.SkipPacketMaxPeriod)
} else {
pns = newSequentialPacketNumberGenerator(initialPN)
}
return &packetNumberSpace{
history: newSentPacketHistory(rttStats),
pns: pns,
largestSent: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
}
}
type sentPacketHandler struct {
initialPackets *packetNumberSpace
handshakePackets *packetNumberSpace
appDataPackets *packetNumberSpace
// Do we know that the peer completed address validation yet?
// Always true for the server.
peerCompletedAddressValidation bool
bytesReceived protocol.ByteCount
bytesSent protocol.ByteCount
// Have we validated the peer's address yet?
// Always true for the client.
peerAddressValidated bool
handshakeConfirmed bool
// lowestNotConfirmedAcked is the lowest packet number that we sent an ACK for, but haven't received confirmation, that this ACK actually arrived
// example: we send an ACK for packets 90-100 with packet number 20
// once we receive an ACK from the peer for packet 20, the lowestNotConfirmedAcked is 101
// Only applies to the application-data packet number space.
lowestNotConfirmedAcked protocol.PacketNumber
ackedPackets []*Packet // to avoid allocations in detectAndRemoveAckedPackets
bytesInFlight protocol.ByteCount
congestion congestion.SendAlgorithmWithDebugInfos
rttStats *utils.RTTStats
// The number of times a PTO has been sent without receiving an ack.
ptoCount uint32
ptoMode SendMode
// The number of PTO probe packets that should be sent.
// Only applies to the application-data packet number space.
numProbesToSend int
// The alarm timeout
alarm time.Time
perspective protocol.Perspective
tracer logging.ConnectionTracer
logger utils.Logger
}
var (
_ SentPacketHandler = &sentPacketHandler{}
_ sentPacketTracker = &sentPacketHandler{}
)
func newSentPacketHandler(
initialPN protocol.PacketNumber,
initialMaxDatagramSize protocol.ByteCount,
rttStats *utils.RTTStats,
pers protocol.Perspective,
tracer logging.ConnectionTracer,
logger utils.Logger,
) *sentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
initialMaxDatagramSize,
true, // use Reno
tracer,
)
return &sentPacketHandler{
peerCompletedAddressValidation: pers == protocol.PerspectiveServer,
peerAddressValidated: pers == protocol.PerspectiveClient,
initialPackets: newPacketNumberSpace(initialPN, false, rttStats),
handshakePackets: newPacketNumberSpace(0, false, rttStats),
appDataPackets: newPacketNumberSpace(0, true, rttStats),
rttStats: rttStats,
congestion: congestion,
perspective: pers,
tracer: tracer,
logger: logger,
}
}
func (h *sentPacketHandler) DropPackets(encLevel protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionInitial {
// This function is called when the crypto setup seals a Handshake packet.
// If this Handshake packet is coalesced behind an Initial packet, we would drop the Initial packet number space
// before SentPacket() was called for that Initial packet.
return
}
h.dropPackets(encLevel)
}
func (h *sentPacketHandler) removeFromBytesInFlight(p *Packet) {
if p.includedInBytesInFlight {
if p.Length > h.bytesInFlight {
panic("negative bytes_in_flight")
}
h.bytesInFlight -= p.Length
p.includedInBytesInFlight = false
}
}
func (h *sentPacketHandler) dropPackets(encLevel protocol.EncryptionLevel) {
// The server won't await address validation after the handshake is confirmed.
// This applies even if we didn't receive an ACK for a Handshake packet.
if h.perspective == protocol.PerspectiveClient && encLevel == protocol.EncryptionHandshake {
h.peerCompletedAddressValidation = true
}
// remove outstanding packets from bytes_in_flight
if encLevel == protocol.EncryptionInitial || encLevel == protocol.EncryptionHandshake {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.history.Iterate(func(p *Packet) (bool, error) {
h.removeFromBytesInFlight(p)
return true, nil
})
}
// drop the packet history
//nolint:exhaustive // Not every packet number space can be dropped.
switch encLevel {
case protocol.EncryptionInitial:
h.initialPackets = nil
case protocol.EncryptionHandshake:
h.handshakePackets = nil
case protocol.Encryption0RTT:
// This function is only called when 0-RTT is rejected,
// and not when the client drops 0-RTT keys when the handshake completes.
// When 0-RTT is rejected, all application data sent so far becomes invalid.
// Delete the packets from the history and remove them from bytes_in_flight.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
if p.EncryptionLevel != protocol.Encryption0RTT {
return false, nil
}
h.removeFromBytesInFlight(p)
h.appDataPackets.history.Remove(p.PacketNumber)
return true, nil
})
default:
panic(fmt.Sprintf("Cannot drop keys for encryption level %s", encLevel))
}
if h.tracer != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
h.numProbesToSend = 0
h.ptoMode = SendNone
h.setLossDetectionTimer()
}
func (h *sentPacketHandler) ReceivedBytes(n protocol.ByteCount) {
wasAmplificationLimit := h.isAmplificationLimited()
h.bytesReceived += n
if wasAmplificationLimit && !h.isAmplificationLimited() {
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) ReceivedPacket(l protocol.EncryptionLevel) {
if h.perspective == protocol.PerspectiveServer && l == protocol.EncryptionHandshake && !h.peerAddressValidated {
h.peerAddressValidated = true
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) packetsInFlight() int {
packetsInFlight := h.appDataPackets.history.Len()
if h.handshakePackets != nil {
packetsInFlight += h.handshakePackets.history.Len()
}
if h.initialPackets != nil {
packetsInFlight += h.initialPackets.history.Len()
}
return packetsInFlight
}
func (h *sentPacketHandler) SentPacket(packet *Packet) {
h.bytesSent += packet.Length
// For the client, drop the Initial packet number space when the first Handshake packet is sent.
if h.perspective == protocol.PerspectiveClient && packet.EncryptionLevel == protocol.EncryptionHandshake && h.initialPackets != nil {
h.dropPackets(protocol.EncryptionInitial)
}
isAckEliciting := h.sentPacketImpl(packet)
h.getPacketNumberSpace(packet.EncryptionLevel).history.SentPacket(packet, isAckEliciting)
if h.tracer != nil && isAckEliciting {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
if isAckEliciting || !h.peerCompletedAddressValidation {
h.setLossDetectionTimer()
}
}
func (h *sentPacketHandler) getPacketNumberSpace(encLevel protocol.EncryptionLevel) *packetNumberSpace {
switch encLevel {
case protocol.EncryptionInitial:
return h.initialPackets
case protocol.EncryptionHandshake:
return h.handshakePackets
case protocol.Encryption0RTT, protocol.Encryption1RTT:
return h.appDataPackets
default:
panic("invalid packet number space")
}
}
func (h *sentPacketHandler) sentPacketImpl(packet *Packet) bool /* is ack-eliciting */ {
pnSpace := h.getPacketNumberSpace(packet.EncryptionLevel)
if h.logger.Debug() && pnSpace.history.HasOutstandingPackets() {
for p := utils.MaxPacketNumber(0, pnSpace.largestSent+1); p < packet.PacketNumber; p++ {
h.logger.Debugf("Skipping packet number %d", p)
}
}
pnSpace.largestSent = packet.PacketNumber
isAckEliciting := len(packet.Frames) > 0
if isAckEliciting {
pnSpace.lastAckElicitingPacketTime = packet.SendTime
packet.includedInBytesInFlight = true
h.bytesInFlight += packet.Length
if h.numProbesToSend > 0 {
h.numProbesToSend--
}
}
h.congestion.OnPacketSent(packet.SendTime, h.bytesInFlight, packet.PacketNumber, packet.Length, isAckEliciting)
return isAckEliciting
}
func (h *sentPacketHandler) ReceivedAck(ack *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) (bool /* contained 1-RTT packet */, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
largestAcked := ack.LargestAcked()
if largestAcked > pnSpace.largestSent {
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: "received ACK for an unsent packet",
}
}
pnSpace.largestAcked = utils.MaxPacketNumber(pnSpace.largestAcked, largestAcked)
// Servers complete address validation when a protected packet is received.
if h.perspective == protocol.PerspectiveClient && !h.peerCompletedAddressValidation &&
(encLevel == protocol.EncryptionHandshake || encLevel == protocol.Encryption1RTT) {
h.peerCompletedAddressValidation = true
h.logger.Debugf("Peer doesn't await address validation any longer.")
// Make sure that the timer is reset, even if this ACK doesn't acknowledge any (ack-eliciting) packets.
h.setLossDetectionTimer()
}
priorInFlight := h.bytesInFlight
ackedPackets, err := h.detectAndRemoveAckedPackets(ack, encLevel)
if err != nil || len(ackedPackets) == 0 {
return false, err
}
// update the RTT, if the largest acked is newly acknowledged
if len(ackedPackets) > 0 {
if p := ackedPackets[len(ackedPackets)-1]; p.PacketNumber == ack.LargestAcked() {
// don't use the ack delay for Initial and Handshake packets
var ackDelay time.Duration
if encLevel == protocol.Encryption1RTT {
ackDelay = utils.MinDuration(ack.DelayTime, h.rttStats.MaxAckDelay())
}
h.rttStats.UpdateRTT(rcvTime.Sub(p.SendTime), ackDelay, rcvTime)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
h.congestion.MaybeExitSlowStart()
}
}
if err := h.detectLostPackets(rcvTime, encLevel); err != nil {
return false, err
}
var acked1RTTPacket bool
for _, p := range ackedPackets {
if p.includedInBytesInFlight && !p.declaredLost {
h.congestion.OnPacketAcked(p.PacketNumber, p.Length, priorInFlight, rcvTime)
}
if p.EncryptionLevel == protocol.Encryption1RTT {
acked1RTTPacket = true
}
h.removeFromBytesInFlight(p)
}
// Reset the pto_count unless the client is unsure if the server has validated the client's address.
if h.peerCompletedAddressValidation {
if h.tracer != nil && h.ptoCount != 0 {
h.tracer.UpdatedPTOCount(0)
}
h.ptoCount = 0
}
h.numProbesToSend = 0
if h.tracer != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
pnSpace.history.DeleteOldPackets(rcvTime)
h.setLossDetectionTimer()
return acked1RTTPacket, nil
}
func (h *sentPacketHandler) GetLowestPacketNotConfirmedAcked() protocol.PacketNumber {
return h.lowestNotConfirmedAcked
}
// Packets are returned in ascending packet number order.
func (h *sentPacketHandler) detectAndRemoveAckedPackets(ack *wire.AckFrame, encLevel protocol.EncryptionLevel) ([]*Packet, error) {
pnSpace := h.getPacketNumberSpace(encLevel)
h.ackedPackets = h.ackedPackets[:0]
ackRangeIndex := 0
lowestAcked := ack.LowestAcked()
largestAcked := ack.LargestAcked()
err := pnSpace.history.Iterate(func(p *Packet) (bool, error) {
// Ignore packets below the lowest acked
if p.PacketNumber < lowestAcked {
return true, nil
}
// Break after largest acked is reached
if p.PacketNumber > largestAcked {
return false, nil
}
if ack.HasMissingRanges() {
ackRange := ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
for p.PacketNumber > ackRange.Largest && ackRangeIndex < len(ack.AckRanges)-1 {
ackRangeIndex++
ackRange = ack.AckRanges[len(ack.AckRanges)-1-ackRangeIndex]
}
if p.PacketNumber < ackRange.Smallest { // packet not contained in ACK range
return true, nil
}
if p.PacketNumber > ackRange.Largest {
return false, fmt.Errorf("BUG: ackhandler would have acked wrong packet %d, while evaluating range %d -> %d", p.PacketNumber, ackRange.Smallest, ackRange.Largest)
}
}
if p.skippedPacket {
return false, &qerr.TransportError{
ErrorCode: qerr.ProtocolViolation,
ErrorMessage: fmt.Sprintf("received an ACK for skipped packet number: %d (%s)", p.PacketNumber, encLevel),
}
}
h.ackedPackets = append(h.ackedPackets, p)
return true, nil
})
if h.logger.Debug() && len(h.ackedPackets) > 0 {
pns := make([]protocol.PacketNumber, len(h.ackedPackets))
for i, p := range h.ackedPackets {
pns[i] = p.PacketNumber
}
h.logger.Debugf("\tnewly acked packets (%d): %d", len(pns), pns)
}
for _, p := range h.ackedPackets {
if p.LargestAcked != protocol.InvalidPacketNumber && encLevel == protocol.Encryption1RTT {
h.lowestNotConfirmedAcked = utils.MaxPacketNumber(h.lowestNotConfirmedAcked, p.LargestAcked+1)
}
for _, f := range p.Frames {
if f.OnAcked != nil {
f.OnAcked(f.Frame)
}
}
if err := pnSpace.history.Remove(p.PacketNumber); err != nil {
return nil, err
}
if h.tracer != nil {
h.tracer.AcknowledgedPacket(encLevel, p.PacketNumber)
}
}
return h.ackedPackets, err
}
func (h *sentPacketHandler) getLossTimeAndSpace() (time.Time, protocol.EncryptionLevel) {
var encLevel protocol.EncryptionLevel
var lossTime time.Time
if h.initialPackets != nil {
lossTime = h.initialPackets.lossTime
encLevel = protocol.EncryptionInitial
}
if h.handshakePackets != nil && (lossTime.IsZero() || (!h.handshakePackets.lossTime.IsZero() && h.handshakePackets.lossTime.Before(lossTime))) {
lossTime = h.handshakePackets.lossTime
encLevel = protocol.EncryptionHandshake
}
if lossTime.IsZero() || (!h.appDataPackets.lossTime.IsZero() && h.appDataPackets.lossTime.Before(lossTime)) {
lossTime = h.appDataPackets.lossTime
encLevel = protocol.Encryption1RTT
}
return lossTime, encLevel
}
// same logic as getLossTimeAndSpace, but for lastAckElicitingPacketTime instead of lossTime
func (h *sentPacketHandler) getPTOTimeAndSpace() (pto time.Time, encLevel protocol.EncryptionLevel, ok bool) {
// We only send application data probe packets once the handshake is confirmed,
// because before that, we don't have the keys to decrypt ACKs sent in 1-RTT packets.
if !h.handshakeConfirmed && !h.hasOutstandingCryptoPackets() {
if h.peerCompletedAddressValidation {
return
}
t := time.Now().Add(h.rttStats.PTO(false) << h.ptoCount)
if h.initialPackets != nil {
return t, protocol.EncryptionInitial, true
}
return t, protocol.EncryptionHandshake, true
}
if h.initialPackets != nil {
encLevel = protocol.EncryptionInitial
if t := h.initialPackets.lastAckElicitingPacketTime; !t.IsZero() {
pto = t.Add(h.rttStats.PTO(false) << h.ptoCount)
}
}
if h.handshakePackets != nil && !h.handshakePackets.lastAckElicitingPacketTime.IsZero() {
t := h.handshakePackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(false) << h.ptoCount)
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.EncryptionHandshake
}
}
if h.handshakeConfirmed && !h.appDataPackets.lastAckElicitingPacketTime.IsZero() {
t := h.appDataPackets.lastAckElicitingPacketTime.Add(h.rttStats.PTO(true) << h.ptoCount)
if pto.IsZero() || (!t.IsZero() && t.Before(pto)) {
pto = t
encLevel = protocol.Encryption1RTT
}
}
return pto, encLevel, true
}
func (h *sentPacketHandler) hasOutstandingCryptoPackets() bool {
var hasInitial, hasHandshake bool
if h.initialPackets != nil {
hasInitial = h.initialPackets.history.HasOutstandingPackets()
}
if h.handshakePackets != nil {
hasHandshake = h.handshakePackets.history.HasOutstandingPackets()
}
return hasInitial || hasHandshake
}
func (h *sentPacketHandler) hasOutstandingPackets() bool {
return h.appDataPackets.history.HasOutstandingPackets() || h.hasOutstandingCryptoPackets()
}
func (h *sentPacketHandler) setLossDetectionTimer() {
oldAlarm := h.alarm // only needed in case tracing is enabled
lossTime, encLevel := h.getLossTimeAndSpace()
if !lossTime.IsZero() {
// Early retransmit timer or time loss detection.
h.alarm = lossTime
if h.tracer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypeACK, encLevel, h.alarm)
}
return
}
// Cancel the alarm if amplification limited.
if h.isAmplificationLimited() {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. Amplification limited.")
if h.tracer != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// Cancel the alarm if no packets are outstanding
if !h.hasOutstandingPackets() && h.peerCompletedAddressValidation {
h.alarm = time.Time{}
if !oldAlarm.IsZero() {
h.logger.Debugf("Canceling loss detection timer. No packets in flight.")
if h.tracer != nil {
h.tracer.LossTimerCanceled()
}
}
return
}
// PTO alarm
ptoTime, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
return
}
h.alarm = ptoTime
if h.tracer != nil && h.alarm != oldAlarm {
h.tracer.SetLossTimer(logging.TimerTypePTO, encLevel, h.alarm)
}
}
func (h *sentPacketHandler) detectLostPackets(now time.Time, encLevel protocol.EncryptionLevel) error {
pnSpace := h.getPacketNumberSpace(encLevel)
pnSpace.lossTime = time.Time{}
maxRTT := float64(utils.MaxDuration(h.rttStats.LatestRTT(), h.rttStats.SmoothedRTT()))
lossDelay := time.Duration(timeThreshold * maxRTT)
// Minimum time of granularity before packets are deemed lost.
lossDelay = utils.MaxDuration(lossDelay, protocol.TimerGranularity)
// Packets sent before this time are deemed lost.
lostSendTime := now.Add(-lossDelay)
priorInFlight := h.bytesInFlight
return pnSpace.history.Iterate(func(p *Packet) (bool, error) {
if p.PacketNumber > pnSpace.largestAcked {
return false, nil
}
if p.declaredLost || p.skippedPacket {
return true, nil
}
var packetLost bool
if p.SendTime.Before(lostSendTime) {
packetLost = true
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (time threshold)", p.PacketNumber)
}
if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossTimeThreshold)
}
} else if pnSpace.largestAcked >= p.PacketNumber+packetThreshold {
packetLost = true
if h.logger.Debug() {
h.logger.Debugf("\tlost packet %d (reordering threshold)", p.PacketNumber)
}
if h.tracer != nil {
h.tracer.LostPacket(p.EncryptionLevel, p.PacketNumber, logging.PacketLossReorderingThreshold)
}
} else if pnSpace.lossTime.IsZero() {
// Note: This conditional is only entered once per call
lossTime := p.SendTime.Add(lossDelay)
if h.logger.Debug() {
h.logger.Debugf("\tsetting loss timer for packet %d (%s) to %s (in %s)", p.PacketNumber, encLevel, lossDelay, lossTime)
}
pnSpace.lossTime = lossTime
}
if packetLost {
p.declaredLost = true
// the bytes in flight need to be reduced no matter if the frames in this packet will be retransmitted
h.removeFromBytesInFlight(p)
h.queueFramesForRetransmission(p)
if !p.IsPathMTUProbePacket {
h.congestion.OnPacketLost(p.PacketNumber, p.Length, priorInFlight)
}
}
return true, nil
})
}
func (h *sentPacketHandler) OnLossDetectionTimeout() error {
defer h.setLossDetectionTimer()
earliestLossTime, encLevel := h.getLossTimeAndSpace()
if !earliestLossTime.IsZero() {
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm fired in loss timer mode. Loss time: %s", earliestLossTime)
}
if h.tracer != nil {
h.tracer.LossTimerExpired(logging.TimerTypeACK, encLevel)
}
// Early retransmit or time loss detection
return h.detectLostPackets(time.Now(), encLevel)
}
// PTO
// When all outstanding are acknowledged, the alarm is canceled in
// setLossDetectionTimer. This doesn't reset the timer in the session though.
// When OnAlarm is called, we therefore need to make sure that there are
// actually packets outstanding.
if h.bytesInFlight == 0 && !h.peerCompletedAddressValidation {
h.ptoCount++
h.numProbesToSend++
if h.initialPackets != nil {
h.ptoMode = SendPTOInitial
} else if h.handshakePackets != nil {
h.ptoMode = SendPTOHandshake
} else {
return errors.New("sentPacketHandler BUG: PTO fired, but bytes_in_flight is 0 and Initial and Handshake already dropped")
}
return nil
}
_, encLevel, ok := h.getPTOTimeAndSpace()
if !ok {
return nil
}
if ps := h.getPacketNumberSpace(encLevel); !ps.history.HasOutstandingPackets() && !h.peerCompletedAddressValidation {
return nil
}
h.ptoCount++
if h.logger.Debug() {
h.logger.Debugf("Loss detection alarm for %s fired in PTO mode. PTO count: %d", encLevel, h.ptoCount)
}
if h.tracer != nil {
h.tracer.LossTimerExpired(logging.TimerTypePTO, encLevel)
h.tracer.UpdatedPTOCount(h.ptoCount)
}
h.numProbesToSend += 2
//nolint:exhaustive // We never arm a PTO timer for 0-RTT packets.
switch encLevel {
case protocol.EncryptionInitial:
h.ptoMode = SendPTOInitial
case protocol.EncryptionHandshake:
h.ptoMode = SendPTOHandshake
case protocol.Encryption1RTT:
// skip a packet number in order to elicit an immediate ACK
_ = h.PopPacketNumber(protocol.Encryption1RTT)
h.ptoMode = SendPTOAppData
default:
return fmt.Errorf("PTO timer in unexpected encryption level: %s", encLevel)
}
return nil
}
func (h *sentPacketHandler) GetLossDetectionTimeout() time.Time {
return h.alarm
}
func (h *sentPacketHandler) PeekPacketNumber(encLevel protocol.EncryptionLevel) (protocol.PacketNumber, protocol.PacketNumberLen) {
pnSpace := h.getPacketNumberSpace(encLevel)
var lowestUnacked protocol.PacketNumber
if p := pnSpace.history.FirstOutstanding(); p != nil {
lowestUnacked = p.PacketNumber
} else {
lowestUnacked = pnSpace.largestAcked + 1
}
pn := pnSpace.pns.Peek()
return pn, protocol.GetPacketNumberLengthForHeader(pn, lowestUnacked)
}
func (h *sentPacketHandler) PopPacketNumber(encLevel protocol.EncryptionLevel) protocol.PacketNumber {
return h.getPacketNumberSpace(encLevel).pns.Pop()
}
func (h *sentPacketHandler) SendMode() SendMode {
numTrackedPackets := h.appDataPackets.history.Len()
if h.initialPackets != nil {
numTrackedPackets += h.initialPackets.history.Len()
}
if h.handshakePackets != nil {
numTrackedPackets += h.handshakePackets.history.Len()
}
if h.isAmplificationLimited() {
h.logger.Debugf("Amplification window limited. Received %d bytes, already sent out %d bytes", h.bytesReceived, h.bytesSent)
return SendNone
}
// Don't send any packets if we're keeping track of the maximum number of packets.
// Note that since MaxOutstandingSentPackets is smaller than MaxTrackedSentPackets,
// we will stop sending out new data when reaching MaxOutstandingSentPackets,
// but still allow sending of retransmissions and ACKs.
if numTrackedPackets >= protocol.MaxTrackedSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Limited by the number of tracked packets: tracking %d packets, maximum %d", numTrackedPackets, protocol.MaxTrackedSentPackets)
}
return SendNone
}
if h.numProbesToSend > 0 {
return h.ptoMode
}
// Only send ACKs if we're congestion limited.
if !h.congestion.CanSend(h.bytesInFlight) {
if h.logger.Debug() {
h.logger.Debugf("Congestion limited: bytes in flight %d, window %d", h.bytesInFlight, h.congestion.GetCongestionWindow())
}
return SendAck
}
if numTrackedPackets >= protocol.MaxOutstandingSentPackets {
if h.logger.Debug() {
h.logger.Debugf("Max outstanding limited: tracking %d packets, maximum: %d", numTrackedPackets, protocol.MaxOutstandingSentPackets)
}
return SendAck
}
return SendAny
}
func (h *sentPacketHandler) TimeUntilSend() time.Time {
return h.congestion.TimeUntilSend(h.bytesInFlight)
}
func (h *sentPacketHandler) HasPacingBudget() bool {
return h.congestion.HasPacingBudget()
}
func (h *sentPacketHandler) SetMaxDatagramSize(s protocol.ByteCount) {
h.congestion.SetMaxDatagramSize(s)
}
func (h *sentPacketHandler) isAmplificationLimited() bool {
if h.peerAddressValidated {
return false
}
return h.bytesSent >= amplificationFactor*h.bytesReceived
}
func (h *sentPacketHandler) QueueProbePacket(encLevel protocol.EncryptionLevel) bool {
pnSpace := h.getPacketNumberSpace(encLevel)
p := pnSpace.history.FirstOutstanding()
if p == nil {
return false
}
h.queueFramesForRetransmission(p)
// TODO: don't declare the packet lost here.
// Keep track of acknowledged frames instead.
h.removeFromBytesInFlight(p)
p.declaredLost = true
return true
}
func (h *sentPacketHandler) queueFramesForRetransmission(p *Packet) {
if len(p.Frames) == 0 {
panic("no frames")
}
for _, f := range p.Frames {
f.OnLost(f.Frame)
}
p.Frames = nil
}
func (h *sentPacketHandler) ResetForRetry() error {
h.bytesInFlight = 0
var firstPacketSendTime time.Time
h.initialPackets.history.Iterate(func(p *Packet) (bool, error) {
if firstPacketSendTime.IsZero() {
firstPacketSendTime = p.SendTime
}
if p.declaredLost || p.skippedPacket {
return true, nil
}
h.queueFramesForRetransmission(p)
return true, nil
})
// All application data packets sent at this point are 0-RTT packets.
// In the case of a Retry, we can assume that the server dropped all of them.
h.appDataPackets.history.Iterate(func(p *Packet) (bool, error) {
if !p.declaredLost && !p.skippedPacket {
h.queueFramesForRetransmission(p)
}
return true, nil
})
// Only use the Retry to estimate the RTT if we didn't send any retransmission for the Initial.
// Otherwise, we don't know which Initial the Retry was sent in response to.
if h.ptoCount == 0 {
// Don't set the RTT to a value lower than 5ms here.
now := time.Now()
h.rttStats.UpdateRTT(utils.MaxDuration(minRTTAfterRetry, now.Sub(firstPacketSendTime)), 0, now)
if h.logger.Debug() {
h.logger.Debugf("\tupdated RTT: %s (σ: %s)", h.rttStats.SmoothedRTT(), h.rttStats.MeanDeviation())
}
if h.tracer != nil {
h.tracer.UpdatedMetrics(h.rttStats, h.congestion.GetCongestionWindow(), h.bytesInFlight, h.packetsInFlight())
}
}
h.initialPackets = newPacketNumberSpace(h.initialPackets.pns.Pop(), false, h.rttStats)
h.appDataPackets = newPacketNumberSpace(h.appDataPackets.pns.Pop(), true, h.rttStats)
oldAlarm := h.alarm
h.alarm = time.Time{}
if h.tracer != nil {
h.tracer.UpdatedPTOCount(0)
if !oldAlarm.IsZero() {
h.tracer.LossTimerCanceled()
}
}
h.ptoCount = 0
return nil
}
func (h *sentPacketHandler) SetHandshakeConfirmed() {
h.handshakeConfirmed = true
// We don't send PTOs for application data packets before the handshake completes.
// Make sure the timer is armed now, if necessary.
h.setLossDetectionTimer()
}

View File

@@ -0,0 +1,108 @@
package ackhandler
import (
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type sentPacketHistory struct {
rttStats *utils.RTTStats
packetList *PacketList
packetMap map[protocol.PacketNumber]*PacketElement
highestSent protocol.PacketNumber
}
func newSentPacketHistory(rttStats *utils.RTTStats) *sentPacketHistory {
return &sentPacketHistory{
rttStats: rttStats,
packetList: NewPacketList(),
packetMap: make(map[protocol.PacketNumber]*PacketElement),
highestSent: protocol.InvalidPacketNumber,
}
}
func (h *sentPacketHistory) SentPacket(p *Packet, isAckEliciting bool) {
if p.PacketNumber <= h.highestSent {
panic("non-sequential packet number use")
}
// Skipped packet numbers.
for pn := h.highestSent + 1; pn < p.PacketNumber; pn++ {
el := h.packetList.PushBack(Packet{
PacketNumber: pn,
EncryptionLevel: p.EncryptionLevel,
SendTime: p.SendTime,
skippedPacket: true,
})
h.packetMap[pn] = el
}
h.highestSent = p.PacketNumber
if isAckEliciting {
el := h.packetList.PushBack(*p)
h.packetMap[p.PacketNumber] = el
}
}
// Iterate iterates through all packets.
func (h *sentPacketHistory) Iterate(cb func(*Packet) (cont bool, err error)) error {
cont := true
var next *PacketElement
for el := h.packetList.Front(); cont && el != nil; el = next {
var err error
next = el.Next()
cont, err = cb(&el.Value)
if err != nil {
return err
}
}
return nil
}
// FirstOutStanding returns the first outstanding packet.
func (h *sentPacketHistory) FirstOutstanding() *Packet {
for el := h.packetList.Front(); el != nil; el = el.Next() {
p := &el.Value
if !p.declaredLost && !p.skippedPacket && !p.IsPathMTUProbePacket {
return p
}
}
return nil
}
func (h *sentPacketHistory) Len() int {
return len(h.packetMap)
}
func (h *sentPacketHistory) Remove(p protocol.PacketNumber) error {
el, ok := h.packetMap[p]
if !ok {
return fmt.Errorf("packet %d not found in sent packet history", p)
}
h.packetList.Remove(el)
delete(h.packetMap, p)
return nil
}
func (h *sentPacketHistory) HasOutstandingPackets() bool {
return h.FirstOutstanding() != nil
}
func (h *sentPacketHistory) DeleteOldPackets(now time.Time) {
maxAge := 3 * h.rttStats.PTO(false)
var nextEl *PacketElement
for el := h.packetList.Front(); el != nil; el = nextEl {
nextEl = el.Next()
p := el.Value
if p.SendTime.After(now.Add(-maxAge)) {
break
}
if !p.skippedPacket && !p.declaredLost { // should only happen in the case of drastic RTT changes
continue
}
delete(h.packetMap, p.PacketNumber)
h.packetList.Remove(el)
}
}

View File

@@ -0,0 +1,25 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// Bandwidth of a connection
type Bandwidth uint64
const infBandwidth Bandwidth = math.MaxUint64
const (
// BitsPerSecond is 1 bit per second
BitsPerSecond Bandwidth = 1
// BytesPerSecond is 1 byte per second
BytesPerSecond = 8 * BitsPerSecond
)
// BandwidthFromDelta calculates the bandwidth from a number of bytes and a time delta
func BandwidthFromDelta(bytes protocol.ByteCount, delta time.Duration) Bandwidth {
return Bandwidth(bytes) * Bandwidth(time.Second) / Bandwidth(delta) * BytesPerSecond
}

View File

@@ -0,0 +1,18 @@
package congestion
import "time"
// A Clock returns the current time
type Clock interface {
Now() time.Time
}
// DefaultClock implements the Clock interface using the Go stdlib clock.
type DefaultClock struct{}
var _ Clock = DefaultClock{}
// Now gets the current time
func (DefaultClock) Now() time.Time {
return time.Now()
}

View File

@@ -0,0 +1,214 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// This cubic implementation is based on the one found in Chromiums's QUIC
// implementation, in the files net/quic/congestion_control/cubic.{hh,cc}.
// Constants based on TCP defaults.
// The following constants are in 2^10 fractions of a second instead of ms to
// allow a 10 shift right to divide.
// 1024*1024^3 (first 1024 is from 0.100^3)
// where 0.100 is 100 ms which is the scaling round trip time.
const (
cubeScale = 40
cubeCongestionWindowScale = 410
cubeFactor protocol.ByteCount = 1 << cubeScale / cubeCongestionWindowScale / maxDatagramSize
// TODO: when re-enabling cubic, make sure to use the actual packet size here
maxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
)
const defaultNumConnections = 1
// Default Cubic backoff factor
const beta float32 = 0.7
// Additional backoff factor when loss occurs in the concave part of the Cubic
// curve. This additional backoff factor is expected to give up bandwidth to
// new concurrent flows and speed up convergence.
const betaLastMax float32 = 0.85
// Cubic implements the cubic algorithm from TCP
type Cubic struct {
clock Clock
// Number of connections to simulate.
numConnections int
// Time when this cycle started, after last loss event.
epoch time.Time
// Max congestion window used just before last loss event.
// Note: to improve fairness to other streams an additional back off is
// applied to this value if the new value is below our latest value.
lastMaxCongestionWindow protocol.ByteCount
// Number of acked bytes since the cycle started (epoch).
ackedBytesCount protocol.ByteCount
// TCP Reno equivalent congestion window in packets.
estimatedTCPcongestionWindow protocol.ByteCount
// Origin point of cubic function.
originPointCongestionWindow protocol.ByteCount
// Time to origin point of cubic function in 2^10 fractions of a second.
timeToOriginPoint uint32
// Last congestion window in packets computed by cubic function.
lastTargetCongestionWindow protocol.ByteCount
}
// NewCubic returns a new Cubic instance
func NewCubic(clock Clock) *Cubic {
c := &Cubic{
clock: clock,
numConnections: defaultNumConnections,
}
c.Reset()
return c
}
// Reset is called after a timeout to reset the cubic state
func (c *Cubic) Reset() {
c.epoch = time.Time{}
c.lastMaxCongestionWindow = 0
c.ackedBytesCount = 0
c.estimatedTCPcongestionWindow = 0
c.originPointCongestionWindow = 0
c.timeToOriginPoint = 0
c.lastTargetCongestionWindow = 0
}
func (c *Cubic) alpha() float32 {
// TCPFriendly alpha is described in Section 3.3 of the CUBIC paper. Note that
// beta here is a cwnd multiplier, and is equal to 1-beta from the paper.
// We derive the equivalent alpha for an N-connection emulation as:
b := c.beta()
return 3 * float32(c.numConnections) * float32(c.numConnections) * (1 - b) / (1 + b)
}
func (c *Cubic) beta() float32 {
// kNConnectionBeta is the backoff factor after loss for our N-connection
// emulation, which emulates the effective backoff of an ensemble of N
// TCP-Reno connections on a single loss event. The effective multiplier is
// computed as:
return (float32(c.numConnections) - 1 + beta) / float32(c.numConnections)
}
func (c *Cubic) betaLastMax() float32 {
// betaLastMax is the additional backoff factor after loss for our
// N-connection emulation, which emulates the additional backoff of
// an ensemble of N TCP-Reno connections on a single loss event. The
// effective multiplier is computed as:
return (float32(c.numConnections) - 1 + betaLastMax) / float32(c.numConnections)
}
// OnApplicationLimited is called on ack arrival when sender is unable to use
// the available congestion window. Resets Cubic state during quiescence.
func (c *Cubic) OnApplicationLimited() {
// When sender is not using the available congestion window, the window does
// not grow. But to be RTT-independent, Cubic assumes that the sender has been
// using the entire window during the time since the beginning of the current
// "epoch" (the end of the last loss recovery period). Since
// application-limited periods break this assumption, we reset the epoch when
// in such a period. This reset effectively freezes congestion window growth
// through application-limited periods and allows Cubic growth to continue
// when the entire window is being used.
c.epoch = time.Time{}
}
// CongestionWindowAfterPacketLoss computes a new congestion window to use after
// a loss event. Returns the new congestion window in packets. The new
// congestion window is a multiplicative decrease of our current window.
func (c *Cubic) CongestionWindowAfterPacketLoss(currentCongestionWindow protocol.ByteCount) protocol.ByteCount {
if currentCongestionWindow+maxDatagramSize < c.lastMaxCongestionWindow {
// We never reached the old max, so assume we are competing with another
// flow. Use our extra back off factor to allow the other flow to go up.
c.lastMaxCongestionWindow = protocol.ByteCount(c.betaLastMax() * float32(currentCongestionWindow))
} else {
c.lastMaxCongestionWindow = currentCongestionWindow
}
c.epoch = time.Time{} // Reset time.
return protocol.ByteCount(float32(currentCongestionWindow) * c.beta())
}
// CongestionWindowAfterAck computes a new congestion window to use after a received ACK.
// Returns the new congestion window in packets. The new congestion window
// follows a cubic function that depends on the time passed since last
// packet loss.
func (c *Cubic) CongestionWindowAfterAck(
ackedBytes protocol.ByteCount,
currentCongestionWindow protocol.ByteCount,
delayMin time.Duration,
eventTime time.Time,
) protocol.ByteCount {
c.ackedBytesCount += ackedBytes
if c.epoch.IsZero() {
// First ACK after a loss event.
c.epoch = eventTime // Start of epoch.
c.ackedBytesCount = ackedBytes // Reset count.
// Reset estimated_tcp_congestion_window_ to be in sync with cubic.
c.estimatedTCPcongestionWindow = currentCongestionWindow
if c.lastMaxCongestionWindow <= currentCongestionWindow {
c.timeToOriginPoint = 0
c.originPointCongestionWindow = currentCongestionWindow
} else {
c.timeToOriginPoint = uint32(math.Cbrt(float64(cubeFactor * (c.lastMaxCongestionWindow - currentCongestionWindow))))
c.originPointCongestionWindow = c.lastMaxCongestionWindow
}
}
// Change the time unit from microseconds to 2^10 fractions per second. Take
// the round trip time in account. This is done to allow us to use shift as a
// divide operator.
elapsedTime := int64(eventTime.Add(delayMin).Sub(c.epoch)/time.Microsecond) << 10 / (1000 * 1000)
// Right-shifts of negative, signed numbers have implementation-dependent
// behavior, so force the offset to be positive, as is done in the kernel.
offset := int64(c.timeToOriginPoint) - elapsedTime
if offset < 0 {
offset = -offset
}
deltaCongestionWindow := protocol.ByteCount(cubeCongestionWindowScale*offset*offset*offset) * maxDatagramSize >> cubeScale
var targetCongestionWindow protocol.ByteCount
if elapsedTime > int64(c.timeToOriginPoint) {
targetCongestionWindow = c.originPointCongestionWindow + deltaCongestionWindow
} else {
targetCongestionWindow = c.originPointCongestionWindow - deltaCongestionWindow
}
// Limit the CWND increase to half the acked bytes.
targetCongestionWindow = utils.MinByteCount(targetCongestionWindow, currentCongestionWindow+c.ackedBytesCount/2)
// Increase the window by approximately Alpha * 1 MSS of bytes every
// time we ack an estimated tcp window of bytes. For small
// congestion windows (less than 25), the formula below will
// increase slightly slower than linearly per estimated tcp window
// of bytes.
c.estimatedTCPcongestionWindow += protocol.ByteCount(float32(c.ackedBytesCount) * c.alpha() * float32(maxDatagramSize) / float32(c.estimatedTCPcongestionWindow))
c.ackedBytesCount = 0
// We have a new cubic congestion window.
c.lastTargetCongestionWindow = targetCongestionWindow
// Compute target congestion_window based on cubic target and estimated TCP
// congestion_window, use highest (fastest).
if targetCongestionWindow < c.estimatedTCPcongestionWindow {
targetCongestionWindow = c.estimatedTCPcongestionWindow
}
return targetCongestionWindow
}
// SetNumConnections sets the number of emulated connections
func (c *Cubic) SetNumConnections(n int) {
c.numConnections = n
}

View File

@@ -0,0 +1,316 @@
package congestion
import (
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
)
const (
// maxDatagramSize is the default maximum packet size used in the Linux TCP implementation.
// Used in QUIC for congestion window computations in bytes.
initialMaxDatagramSize = protocol.ByteCount(protocol.InitialPacketSizeIPv4)
maxBurstPackets = 3
renoBeta = 0.7 // Reno backoff factor.
minCongestionWindowPackets = 2
initialCongestionWindow = 32
)
type cubicSender struct {
hybridSlowStart HybridSlowStart
rttStats *utils.RTTStats
cubic *Cubic
pacer *pacer
clock Clock
reno bool
// Track the largest packet that has been sent.
largestSentPacketNumber protocol.PacketNumber
// Track the largest packet that has been acked.
largestAckedPacketNumber protocol.PacketNumber
// Track the largest packet number outstanding when a CWND cutback occurs.
largestSentAtLastCutback protocol.PacketNumber
// Whether the last loss event caused us to exit slowstart.
// Used for stats collection of slowstartPacketsLost
lastCutbackExitedSlowstart bool
// Congestion window in packets.
congestionWindow protocol.ByteCount
// Slow start congestion window in bytes, aka ssthresh.
slowStartThreshold protocol.ByteCount
// ACK counter for the Reno implementation.
numAckedPackets uint64
initialCongestionWindow protocol.ByteCount
initialMaxCongestionWindow protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastState logging.CongestionState
tracer logging.ConnectionTracer
}
var (
_ SendAlgorithm = &cubicSender{}
_ SendAlgorithmWithDebugInfos = &cubicSender{}
)
// NewCubicSender makes a new cubic sender
func NewCubicSender(
clock Clock,
rttStats *utils.RTTStats,
initialMaxDatagramSize protocol.ByteCount,
reno bool,
tracer logging.ConnectionTracer,
) *cubicSender {
return newCubicSender(
clock,
rttStats,
reno,
initialMaxDatagramSize,
initialCongestionWindow*initialMaxDatagramSize,
protocol.MaxCongestionWindowPackets*initialMaxDatagramSize,
tracer,
)
}
func newCubicSender(
clock Clock,
rttStats *utils.RTTStats,
reno bool,
initialMaxDatagramSize,
initialCongestionWindow,
initialMaxCongestionWindow protocol.ByteCount,
tracer logging.ConnectionTracer,
) *cubicSender {
c := &cubicSender{
rttStats: rttStats,
largestSentPacketNumber: protocol.InvalidPacketNumber,
largestAckedPacketNumber: protocol.InvalidPacketNumber,
largestSentAtLastCutback: protocol.InvalidPacketNumber,
initialCongestionWindow: initialCongestionWindow,
initialMaxCongestionWindow: initialMaxCongestionWindow,
congestionWindow: initialCongestionWindow,
slowStartThreshold: protocol.MaxByteCount,
cubic: NewCubic(clock),
clock: clock,
reno: reno,
tracer: tracer,
maxDatagramSize: initialMaxDatagramSize,
}
c.pacer = newPacer(c.BandwidthEstimate)
if c.tracer != nil {
c.lastState = logging.CongestionStateSlowStart
c.tracer.UpdatedCongestionState(logging.CongestionStateSlowStart)
}
return c
}
// TimeUntilSend returns when the next packet should be sent.
func (c *cubicSender) TimeUntilSend(_ protocol.ByteCount) time.Time {
return c.pacer.TimeUntilSend()
}
func (c *cubicSender) HasPacingBudget() bool {
return c.pacer.Budget(c.clock.Now()) >= c.maxDatagramSize
}
func (c *cubicSender) maxCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * protocol.MaxCongestionWindowPackets
}
func (c *cubicSender) minCongestionWindow() protocol.ByteCount {
return c.maxDatagramSize * minCongestionWindowPackets
}
func (c *cubicSender) OnPacketSent(
sentTime time.Time,
_ protocol.ByteCount,
packetNumber protocol.PacketNumber,
bytes protocol.ByteCount,
isRetransmittable bool,
) {
c.pacer.SentPacket(sentTime, bytes)
if !isRetransmittable {
return
}
c.largestSentPacketNumber = packetNumber
c.hybridSlowStart.OnPacketSent(packetNumber)
}
func (c *cubicSender) CanSend(bytesInFlight protocol.ByteCount) bool {
return bytesInFlight < c.GetCongestionWindow()
}
func (c *cubicSender) InRecovery() bool {
return c.largestAckedPacketNumber != protocol.InvalidPacketNumber && c.largestAckedPacketNumber <= c.largestSentAtLastCutback
}
func (c *cubicSender) InSlowStart() bool {
return c.GetCongestionWindow() < c.slowStartThreshold
}
func (c *cubicSender) GetCongestionWindow() protocol.ByteCount {
return c.congestionWindow
}
func (c *cubicSender) MaybeExitSlowStart() {
if c.InSlowStart() &&
c.hybridSlowStart.ShouldExitSlowStart(c.rttStats.LatestRTT(), c.rttStats.MinRTT(), c.GetCongestionWindow()/c.maxDatagramSize) {
// exit slow start
c.slowStartThreshold = c.congestionWindow
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
}
}
func (c *cubicSender) OnPacketAcked(
ackedPacketNumber protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
c.largestAckedPacketNumber = utils.MaxPacketNumber(ackedPacketNumber, c.largestAckedPacketNumber)
if c.InRecovery() {
return
}
c.maybeIncreaseCwnd(ackedPacketNumber, ackedBytes, priorInFlight, eventTime)
if c.InSlowStart() {
c.hybridSlowStart.OnPacketAcked(ackedPacketNumber)
}
}
func (c *cubicSender) OnPacketLost(packetNumber protocol.PacketNumber, lostBytes, priorInFlight protocol.ByteCount) {
// TCP NewReno (RFC6582) says that once a loss occurs, any losses in packets
// already sent should be treated as a single loss event, since it's expected.
if packetNumber <= c.largestSentAtLastCutback {
return
}
c.lastCutbackExitedSlowstart = c.InSlowStart()
c.maybeTraceStateChange(logging.CongestionStateRecovery)
if c.reno {
c.congestionWindow = protocol.ByteCount(float64(c.congestionWindow) * renoBeta)
} else {
c.congestionWindow = c.cubic.CongestionWindowAfterPacketLoss(c.congestionWindow)
}
if minCwnd := c.minCongestionWindow(); c.congestionWindow < minCwnd {
c.congestionWindow = minCwnd
}
c.slowStartThreshold = c.congestionWindow
c.largestSentAtLastCutback = c.largestSentPacketNumber
// reset packet count from congestion avoidance mode. We start
// counting again when we're out of recovery.
c.numAckedPackets = 0
}
// Called when we receive an ack. Normal TCP tracks how many packets one ack
// represents, but quic has a separate ack for each packet.
func (c *cubicSender) maybeIncreaseCwnd(
_ protocol.PacketNumber,
ackedBytes protocol.ByteCount,
priorInFlight protocol.ByteCount,
eventTime time.Time,
) {
// Do not increase the congestion window unless the sender is close to using
// the current window.
if !c.isCwndLimited(priorInFlight) {
c.cubic.OnApplicationLimited()
c.maybeTraceStateChange(logging.CongestionStateApplicationLimited)
return
}
if c.congestionWindow >= c.maxCongestionWindow() {
return
}
if c.InSlowStart() {
// TCP slow start, exponential growth, increase by one for each ACK.
c.congestionWindow += c.maxDatagramSize
c.maybeTraceStateChange(logging.CongestionStateSlowStart)
return
}
// Congestion avoidance
c.maybeTraceStateChange(logging.CongestionStateCongestionAvoidance)
if c.reno {
// Classic Reno congestion avoidance.
c.numAckedPackets++
if c.numAckedPackets >= uint64(c.congestionWindow/c.maxDatagramSize) {
c.congestionWindow += c.maxDatagramSize
c.numAckedPackets = 0
}
} else {
c.congestionWindow = utils.MinByteCount(c.maxCongestionWindow(), c.cubic.CongestionWindowAfterAck(ackedBytes, c.congestionWindow, c.rttStats.MinRTT(), eventTime))
}
}
func (c *cubicSender) isCwndLimited(bytesInFlight protocol.ByteCount) bool {
congestionWindow := c.GetCongestionWindow()
if bytesInFlight >= congestionWindow {
return true
}
availableBytes := congestionWindow - bytesInFlight
slowStartLimited := c.InSlowStart() && bytesInFlight > congestionWindow/2
return slowStartLimited || availableBytes <= maxBurstPackets*c.maxDatagramSize
}
// BandwidthEstimate returns the current bandwidth estimate
func (c *cubicSender) BandwidthEstimate() Bandwidth {
srtt := c.rttStats.SmoothedRTT()
if srtt == 0 {
// If we haven't measured an rtt, the bandwidth estimate is unknown.
return infBandwidth
}
return BandwidthFromDelta(c.GetCongestionWindow(), srtt)
}
// OnRetransmissionTimeout is called on an retransmission timeout
func (c *cubicSender) OnRetransmissionTimeout(packetsRetransmitted bool) {
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
if !packetsRetransmitted {
return
}
c.hybridSlowStart.Restart()
c.cubic.Reset()
c.slowStartThreshold = c.congestionWindow / 2
c.congestionWindow = c.minCongestionWindow()
}
// OnConnectionMigration is called when the connection is migrated (?)
func (c *cubicSender) OnConnectionMigration() {
c.hybridSlowStart.Restart()
c.largestSentPacketNumber = protocol.InvalidPacketNumber
c.largestAckedPacketNumber = protocol.InvalidPacketNumber
c.largestSentAtLastCutback = protocol.InvalidPacketNumber
c.lastCutbackExitedSlowstart = false
c.cubic.Reset()
c.numAckedPackets = 0
c.congestionWindow = c.initialCongestionWindow
c.slowStartThreshold = c.initialMaxCongestionWindow
}
func (c *cubicSender) maybeTraceStateChange(new logging.CongestionState) {
if c.tracer == nil || new == c.lastState {
return
}
c.tracer.UpdatedCongestionState(new)
c.lastState = new
}
func (c *cubicSender) SetMaxDatagramSize(s protocol.ByteCount) {
if s < c.maxDatagramSize {
panic(fmt.Sprintf("congestion BUG: decreased max datagram size from %d to %d", c.maxDatagramSize, s))
}
cwndIsMinCwnd := c.congestionWindow == c.minCongestionWindow()
c.maxDatagramSize = s
if cwndIsMinCwnd {
c.congestionWindow = c.minCongestionWindow()
}
c.pacer.SetMaxDatagramSize(s)
}

View File

@@ -0,0 +1,113 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// Note(pwestin): the magic clamping numbers come from the original code in
// tcp_cubic.c.
const hybridStartLowWindow = protocol.ByteCount(16)
// Number of delay samples for detecting the increase of delay.
const hybridStartMinSamples = uint32(8)
// Exit slow start if the min rtt has increased by more than 1/8th.
const hybridStartDelayFactorExp = 3 // 2^3 = 8
// The original paper specifies 2 and 8ms, but those have changed over time.
const (
hybridStartDelayMinThresholdUs = int64(4000)
hybridStartDelayMaxThresholdUs = int64(16000)
)
// HybridSlowStart implements the TCP hybrid slow start algorithm
type HybridSlowStart struct {
endPacketNumber protocol.PacketNumber
lastSentPacketNumber protocol.PacketNumber
started bool
currentMinRTT time.Duration
rttSampleCount uint32
hystartFound bool
}
// StartReceiveRound is called for the start of each receive round (burst) in the slow start phase.
func (s *HybridSlowStart) StartReceiveRound(lastSent protocol.PacketNumber) {
s.endPacketNumber = lastSent
s.currentMinRTT = 0
s.rttSampleCount = 0
s.started = true
}
// IsEndOfRound returns true if this ack is the last packet number of our current slow start round.
func (s *HybridSlowStart) IsEndOfRound(ack protocol.PacketNumber) bool {
return s.endPacketNumber < ack
}
// ShouldExitSlowStart should be called on every new ack frame, since a new
// RTT measurement can be made then.
// rtt: the RTT for this ack packet.
// minRTT: is the lowest delay (RTT) we have seen during the session.
// congestionWindow: the congestion window in packets.
func (s *HybridSlowStart) ShouldExitSlowStart(latestRTT time.Duration, minRTT time.Duration, congestionWindow protocol.ByteCount) bool {
if !s.started {
// Time to start the hybrid slow start.
s.StartReceiveRound(s.lastSentPacketNumber)
}
if s.hystartFound {
return true
}
// Second detection parameter - delay increase detection.
// Compare the minimum delay (s.currentMinRTT) of the current
// burst of packets relative to the minimum delay during the session.
// Note: we only look at the first few(8) packets in each burst, since we
// only want to compare the lowest RTT of the burst relative to previous
// bursts.
s.rttSampleCount++
if s.rttSampleCount <= hybridStartMinSamples {
if s.currentMinRTT == 0 || s.currentMinRTT > latestRTT {
s.currentMinRTT = latestRTT
}
}
// We only need to check this once per round.
if s.rttSampleCount == hybridStartMinSamples {
// Divide minRTT by 8 to get a rtt increase threshold for exiting.
minRTTincreaseThresholdUs := int64(minRTT / time.Microsecond >> hybridStartDelayFactorExp)
// Ensure the rtt threshold is never less than 2ms or more than 16ms.
minRTTincreaseThresholdUs = utils.MinInt64(minRTTincreaseThresholdUs, hybridStartDelayMaxThresholdUs)
minRTTincreaseThreshold := time.Duration(utils.MaxInt64(minRTTincreaseThresholdUs, hybridStartDelayMinThresholdUs)) * time.Microsecond
if s.currentMinRTT > (minRTT + minRTTincreaseThreshold) {
s.hystartFound = true
}
}
// Exit from slow start if the cwnd is greater than 16 and
// increasing delay is found.
return congestionWindow >= hybridStartLowWindow && s.hystartFound
}
// OnPacketSent is called when a packet was sent
func (s *HybridSlowStart) OnPacketSent(packetNumber protocol.PacketNumber) {
s.lastSentPacketNumber = packetNumber
}
// OnPacketAcked gets invoked after ShouldExitSlowStart, so it's best to end
// the round when the final packet of the burst is received and start it on
// the next incoming ack.
func (s *HybridSlowStart) OnPacketAcked(ackedPacketNumber protocol.PacketNumber) {
if s.IsEndOfRound(ackedPacketNumber) {
s.started = false
}
}
// Started returns true if started
func (s *HybridSlowStart) Started() bool {
return s.started
}
// Restart the slow start phase
func (s *HybridSlowStart) Restart() {
s.started = false
s.hystartFound = false
}

View File

@@ -0,0 +1,28 @@
package congestion
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A SendAlgorithm performs congestion control
type SendAlgorithm interface {
TimeUntilSend(bytesInFlight protocol.ByteCount) time.Time
HasPacingBudget() bool
OnPacketSent(sentTime time.Time, bytesInFlight protocol.ByteCount, packetNumber protocol.PacketNumber, bytes protocol.ByteCount, isRetransmittable bool)
CanSend(bytesInFlight protocol.ByteCount) bool
MaybeExitSlowStart()
OnPacketAcked(number protocol.PacketNumber, ackedBytes protocol.ByteCount, priorInFlight protocol.ByteCount, eventTime time.Time)
OnPacketLost(number protocol.PacketNumber, lostBytes protocol.ByteCount, priorInFlight protocol.ByteCount)
OnRetransmissionTimeout(packetsRetransmitted bool)
SetMaxDatagramSize(protocol.ByteCount)
}
// A SendAlgorithmWithDebugInfos is a SendAlgorithm that exposes some debug infos
type SendAlgorithmWithDebugInfos interface {
SendAlgorithm
InSlowStart() bool
InRecovery() bool
GetCongestionWindow() protocol.ByteCount
}

View File

@@ -0,0 +1,77 @@
package congestion
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
const maxBurstSizePackets = 10
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent protocol.ByteCount
maxDatagramSize protocol.ByteCount
lastSentTime time.Time
getAdjustedBandwidth func() uint64 // in bytes/s
}
func newPacer(getBandwidth func() Bandwidth) *pacer {
p := &pacer{
maxDatagramSize: initialMaxDatagramSize,
getAdjustedBandwidth: func() uint64 {
// Bandwidth is in bits/s. We need the value in bytes/s.
bw := uint64(getBandwidth() / BytesPerSecond)
// Use a slightly higher value than the actual measured bandwidth.
// RTT variations then won't result in under-utilization of the congestion window.
// Ultimately, this will result in sending packets as acknowledgments are received rather than when timers fire,
// provided the congestion window is fully utilized and acknowledgments arrive at regular intervals.
return bw * 5 / 4
},
}
p.budgetAtLastSent = p.maxBurstSize()
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size protocol.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) protocol.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (protocol.ByteCount(p.getAdjustedBandwidth())*protocol.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
return utils.MinByteCount(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() protocol.ByteCount {
return utils.MaxByteCount(
protocol.ByteCount(uint64((protocol.MinPacingDelay+protocol.TimerGranularity).Nanoseconds())*p.getAdjustedBandwidth())/1e9,
maxBurstSizePackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(utils.MaxDuration(
protocol.MinPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/float64(p.getAdjustedBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s protocol.ByteCount) {
p.maxDatagramSize = s
}

View File

@@ -0,0 +1,120 @@
package flowcontrol
import (
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type baseFlowController struct {
// for sending data
bytesSent protocol.ByteCount
sendWindow protocol.ByteCount
lastBlockedAt protocol.ByteCount
// for receiving data
//nolint:structcheck // The mutex is used both by the stream and the connection flow controller
mutex sync.Mutex
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveWindow protocol.ByteCount
receiveWindowSize protocol.ByteCount
maxReceiveWindowSize protocol.ByteCount
epochStartTime time.Time
epochStartOffset protocol.ByteCount
rttStats *utils.RTTStats
logger utils.Logger
}
// IsNewlyBlocked says if it is newly blocked by flow control.
// For every offset, it only returns true once.
// If it is blocked, the offset is returned.
func (c *baseFlowController) IsNewlyBlocked() (bool, protocol.ByteCount) {
if c.sendWindowSize() != 0 || c.sendWindow == c.lastBlockedAt {
return false, 0
}
c.lastBlockedAt = c.sendWindow
return true, c.sendWindow
}
func (c *baseFlowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
// UpdateSendWindow is be called after receiving a MAX_{STREAM_}DATA frame.
func (c *baseFlowController) UpdateSendWindow(offset protocol.ByteCount) {
if offset > c.sendWindow {
c.sendWindow = offset
}
}
func (c *baseFlowController) sendWindowSize() protocol.ByteCount {
// this only happens during connection establishment, when data is sent before we receive the peer's transport parameters
if c.bytesSent > c.sendWindow {
return 0
}
return c.sendWindow - c.bytesSent
}
// needs to be called with locked mutex
func (c *baseFlowController) addBytesRead(n protocol.ByteCount) {
// pretend we sent a WindowUpdate when reading the first byte
// this way auto-tuning of the window size already works for the first WindowUpdate
if c.bytesRead == 0 {
c.startNewAutoTuningEpoch(time.Now())
}
c.bytesRead += n
}
func (c *baseFlowController) hasWindowUpdate() bool {
bytesRemaining := c.receiveWindow - c.bytesRead
// update the window when more than the threshold was consumed
return bytesRemaining <= protocol.ByteCount(float64(c.receiveWindowSize)*(1-protocol.WindowUpdateThreshold))
}
// getWindowUpdate updates the receive window, if necessary
// it returns the new offset
func (c *baseFlowController) getWindowUpdate() protocol.ByteCount {
if !c.hasWindowUpdate() {
return 0
}
c.maybeAdjustWindowSize()
c.receiveWindow = c.bytesRead + c.receiveWindowSize
return c.receiveWindow
}
// maybeAdjustWindowSize increases the receiveWindowSize if we're sending updates too often.
// For details about auto-tuning, see https://docs.google.com/document/d/1SExkMmGiz8VYzV3s9E35JQlJ73vhzCekKkDi85F1qCE/edit?usp=sharing.
func (c *baseFlowController) maybeAdjustWindowSize() {
bytesReadInEpoch := c.bytesRead - c.epochStartOffset
// don't do anything if less than half the window has been consumed
if bytesReadInEpoch <= c.receiveWindowSize/2 {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
fraction := float64(bytesReadInEpoch) / float64(c.receiveWindowSize)
now := time.Now()
if now.Sub(c.epochStartTime) < time.Duration(4*fraction*float64(rtt)) {
// window is consumed too fast, try to increase the window size
c.receiveWindowSize = utils.MinByteCount(2*c.receiveWindowSize, c.maxReceiveWindowSize)
}
c.startNewAutoTuningEpoch(now)
}
func (c *baseFlowController) startNewAutoTuningEpoch(now time.Time) {
c.epochStartTime = now
c.epochStartOffset = c.bytesRead
}
func (c *baseFlowController) checkFlowControlViolation() bool {
return c.highestReceived > c.receiveWindow
}

View File

@@ -0,0 +1,107 @@
package flowcontrol
import (
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type connectionFlowController struct {
baseFlowController
queueWindowUpdate func()
}
var _ ConnectionFlowController = &connectionFlowController{}
// NewConnectionFlowController gets a new flow controller for the connection
// It is created before we receive the peer's transport paramenters, thus it starts with a sendWindow of 0.
func NewConnectionFlowController(
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
queueWindowUpdate func(),
rttStats *utils.RTTStats,
logger utils.Logger,
) ConnectionFlowController {
return &connectionFlowController{
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
logger: logger,
},
queueWindowUpdate: queueWindowUpdate,
}
}
func (c *connectionFlowController) SendWindowSize() protocol.ByteCount {
return c.baseFlowController.sendWindowSize()
}
// IncrementHighestReceived adds an increment to the highestReceived value
func (c *connectionFlowController) IncrementHighestReceived(increment protocol.ByteCount) error {
c.mutex.Lock()
defer c.mutex.Unlock()
c.highestReceived += increment
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes for the connection, allowed %d bytes", c.highestReceived, c.receiveWindow),
}
}
return nil
}
func (c *connectionFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.hasWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
}
func (c *connectionFlowController) GetWindowUpdate() protocol.ByteCount {
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if oldWindowSize < c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB", c.receiveWindowSize/(1<<10))
}
c.mutex.Unlock()
return offset
}
// EnsureMinimumWindowSize sets a minimum window size
// it should make sure that the connection-level window is increased when a stream-level window grows
func (c *connectionFlowController) EnsureMinimumWindowSize(inc protocol.ByteCount) {
c.mutex.Lock()
if inc > c.receiveWindowSize {
c.logger.Debugf("Increasing receive flow control window for the connection to %d kB, in response to stream flow control window increase", c.receiveWindowSize/(1<<10))
c.receiveWindowSize = utils.MinByteCount(inc, c.maxReceiveWindowSize)
c.startNewAutoTuningEpoch(time.Now())
}
c.mutex.Unlock()
}
// The flow controller is reset when 0-RTT is rejected.
// All stream data is invalidated, it's if we had never opened a stream and never sent any data.
// At that point, we only have sent stream data, but we didn't have the keys to open 1-RTT keys yet.
func (c *connectionFlowController) Reset() error {
c.mutex.Lock()
defer c.mutex.Unlock()
if c.bytesRead > 0 || c.highestReceived > 0 || !c.epochStartTime.IsZero() {
return errors.New("flow controller reset after reading data")
}
c.bytesSent = 0
c.lastBlockedAt = 0
return nil
}

View File

@@ -0,0 +1,42 @@
package flowcontrol
import "github.com/lucas-clemente/quic-go/internal/protocol"
type flowController interface {
// for sending
SendWindowSize() protocol.ByteCount
UpdateSendWindow(protocol.ByteCount)
AddBytesSent(protocol.ByteCount)
// for receiving
AddBytesRead(protocol.ByteCount)
GetWindowUpdate() protocol.ByteCount // returns 0 if no update is necessary
IsNewlyBlocked() (bool, protocol.ByteCount)
}
// A StreamFlowController is a flow controller for a QUIC stream.
type StreamFlowController interface {
flowController
// for receiving
// UpdateHighestReceived should be called when a new highest offset is received
// final has to be to true if this is the final offset of the stream,
// as contained in a STREAM frame with FIN bit, and the RESET_STREAM frame
UpdateHighestReceived(offset protocol.ByteCount, final bool) error
// Abandon should be called when reading from the stream is aborted early,
// and there won't be any further calls to AddBytesRead.
Abandon()
}
// The ConnectionFlowController is the flow controller for the connection.
type ConnectionFlowController interface {
flowController
Reset() error
}
type connectionFlowControllerI interface {
ConnectionFlowController
// The following two methods are not supposed to be called from outside this packet, but are needed internally
// for sending
EnsureMinimumWindowSize(protocol.ByteCount)
// for receiving
IncrementHighestReceived(protocol.ByteCount) error
}

View File

@@ -0,0 +1,146 @@
package flowcontrol
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/utils"
)
type streamFlowController struct {
baseFlowController
streamID protocol.StreamID
queueWindowUpdate func()
connection connectionFlowControllerI
receivedFinalOffset bool
}
var _ StreamFlowController = &streamFlowController{}
// NewStreamFlowController gets a new flow controller for a stream
func NewStreamFlowController(
streamID protocol.StreamID,
cfc ConnectionFlowController,
receiveWindow protocol.ByteCount,
maxReceiveWindow protocol.ByteCount,
initialSendWindow protocol.ByteCount,
queueWindowUpdate func(protocol.StreamID),
rttStats *utils.RTTStats,
logger utils.Logger,
) StreamFlowController {
return &streamFlowController{
streamID: streamID,
connection: cfc.(connectionFlowControllerI),
queueWindowUpdate: func() { queueWindowUpdate(streamID) },
baseFlowController: baseFlowController{
rttStats: rttStats,
receiveWindow: receiveWindow,
receiveWindowSize: receiveWindow,
maxReceiveWindowSize: maxReceiveWindow,
sendWindow: initialSendWindow,
logger: logger,
},
}
}
// UpdateHighestReceived updates the highestReceived value, if the offset is higher.
func (c *streamFlowController) UpdateHighestReceived(offset protocol.ByteCount, final bool) error {
// If the final offset for this stream is already known, check for consistency.
if c.receivedFinalOffset {
// If we receive another final offset, check that it's the same.
if final && offset != c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received inconsistent final offset for stream %d (old: %d, new: %d bytes)", c.streamID, c.highestReceived, offset),
}
}
// Check that the offset is below the final offset.
if offset > c.highestReceived {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received offset %d for stream %d, but final offset was already received at %d", offset, c.streamID, c.highestReceived),
}
}
}
if final {
c.receivedFinalOffset = true
}
if offset == c.highestReceived {
return nil
}
// A higher offset was received before.
// This can happen due to reordering.
if offset <= c.highestReceived {
if final {
return &qerr.TransportError{
ErrorCode: qerr.FinalSizeError,
ErrorMessage: fmt.Sprintf("received final offset %d for stream %d, but already received offset %d before", offset, c.streamID, c.highestReceived),
}
}
return nil
}
increment := offset - c.highestReceived
c.highestReceived = offset
if c.checkFlowControlViolation() {
return &qerr.TransportError{
ErrorCode: qerr.FlowControlError,
ErrorMessage: fmt.Sprintf("received %d bytes on stream %d, allowed %d bytes", offset, c.streamID, c.receiveWindow),
}
}
return c.connection.IncrementHighestReceived(increment)
}
func (c *streamFlowController) AddBytesRead(n protocol.ByteCount) {
c.mutex.Lock()
c.baseFlowController.addBytesRead(n)
shouldQueueWindowUpdate := c.shouldQueueWindowUpdate()
c.mutex.Unlock()
if shouldQueueWindowUpdate {
c.queueWindowUpdate()
}
c.connection.AddBytesRead(n)
}
func (c *streamFlowController) Abandon() {
if unread := c.highestReceived - c.bytesRead; unread > 0 {
c.connection.AddBytesRead(unread)
}
}
func (c *streamFlowController) AddBytesSent(n protocol.ByteCount) {
c.baseFlowController.AddBytesSent(n)
c.connection.AddBytesSent(n)
}
func (c *streamFlowController) SendWindowSize() protocol.ByteCount {
return utils.MinByteCount(c.baseFlowController.sendWindowSize(), c.connection.SendWindowSize())
}
func (c *streamFlowController) shouldQueueWindowUpdate() bool {
return !c.receivedFinalOffset && c.hasWindowUpdate()
}
func (c *streamFlowController) GetWindowUpdate() protocol.ByteCount {
// If we already received the final offset for this stream, the peer won't need any additional flow control credit.
if c.receivedFinalOffset {
return 0
}
// Don't use defer for unlocking the mutex here, GetWindowUpdate() is called frequently and defer shows up in the profiler
c.mutex.Lock()
oldWindowSize := c.receiveWindowSize
offset := c.baseFlowController.getWindowUpdate()
if c.receiveWindowSize > oldWindowSize { // auto-tuning enlarged the window size
c.logger.Debugf("Increasing receive flow control window for stream %d to %d kB", c.streamID, c.receiveWindowSize/(1<<10))
c.connection.EnsureMinimumWindowSize(protocol.ByteCount(float64(c.receiveWindowSize) * protocol.ConnectionFlowControlMultiplier))
}
c.mutex.Unlock()
return offset
}

View File

@@ -0,0 +1,155 @@
package handshake
import (
"crypto/cipher"
"encoding/binary"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
)
func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
key := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic key", suite.KeyLen)
iv := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic iv", suite.IVLen())
return suite.AEAD(key, iv)
}
type longHeaderSealer struct {
aead cipher.AEAD
headerProtector headerProtector
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderSealer = &longHeaderSealer{}
func newLongHeaderSealer(aead cipher.AEAD, headerProtector headerProtector) LongHeaderSealer {
return &longHeaderSealer{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (s *longHeaderSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
binary.BigEndian.PutUint64(s.nonceBuf[len(s.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return s.aead.Seal(dst, s.nonceBuf, src, ad)
}
func (s *longHeaderSealer) EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
s.headerProtector.EncryptHeader(sample, firstByte, pnBytes)
}
func (s *longHeaderSealer) Overhead() int {
return s.aead.Overhead()
}
type longHeaderOpener struct {
aead cipher.AEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
// use a single slice to avoid allocations
nonceBuf []byte
}
var _ LongHeaderOpener = &longHeaderOpener{}
func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) LongHeaderOpener {
return &longHeaderOpener{
aead: aead,
headerProtector: headerProtector,
nonceBuf: make([]byte, aead.NonceSize()),
}
}
func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}
func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err == nil {
o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
}
func (o *longHeaderOpener) DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) {
o.headerProtector.DecryptHeader(sample, firstByte, pnBytes)
}
type handshakeSealer struct {
LongHeaderSealer
dropInitialKeys func()
dropped bool
}
func newHandshakeSealer(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderSealer {
sealer := newLongHeaderSealer(aead, headerProtector)
// The client drops Initial keys when sending the first Handshake packet.
if perspective == protocol.PerspectiveServer {
return sealer
}
return &handshakeSealer{
LongHeaderSealer: sealer,
dropInitialKeys: dropInitialKeys,
}
}
func (s *handshakeSealer) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
data := s.LongHeaderSealer.Seal(dst, src, pn, ad)
if !s.dropped {
s.dropInitialKeys()
s.dropped = true
}
return data
}
type handshakeOpener struct {
LongHeaderOpener
dropInitialKeys func()
dropped bool
}
func newHandshakeOpener(
aead cipher.AEAD,
headerProtector headerProtector,
dropInitialKeys func(),
perspective protocol.Perspective,
) LongHeaderOpener {
opener := newLongHeaderOpener(aead, headerProtector)
// The server drops Initial keys when first successfully processing a Handshake packet.
if perspective == protocol.PerspectiveClient {
return opener
}
return &handshakeOpener{
LongHeaderOpener: opener,
dropInitialKeys: dropInitialKeys,
}
}
func (o *handshakeOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
dec, err := o.LongHeaderOpener.Open(dst, src, pn, ad)
if err == nil && !o.dropped {
o.dropInitialKeys()
o.dropped = true
}
return dec, err
}

View File

@@ -0,0 +1,800 @@
package handshake
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// TLS unexpected_message alert
const alertUnexpectedMessage uint8 = 10
type messageType uint8
// TLS handshake message types.
const (
typeClientHello messageType = 1
typeServerHello messageType = 2
typeNewSessionTicket messageType = 4
typeEncryptedExtensions messageType = 8
typeCertificate messageType = 11
typeCertificateRequest messageType = 13
typeCertificateVerify messageType = 15
typeFinished messageType = 20
)
func (m messageType) String() string {
switch m {
case typeClientHello:
return "ClientHello"
case typeServerHello:
return "ServerHello"
case typeNewSessionTicket:
return "NewSessionTicket"
case typeEncryptedExtensions:
return "EncryptedExtensions"
case typeCertificate:
return "Certificate"
case typeCertificateRequest:
return "CertificateRequest"
case typeCertificateVerify:
return "CertificateVerify"
case typeFinished:
return "Finished"
default:
return fmt.Sprintf("unknown message type: %d", m)
}
}
const clientSessionStateRevision = 3
type conn struct {
localAddr, remoteAddr net.Addr
version protocol.VersionNumber
}
var _ ConnWithVersion = &conn{}
func newConn(local, remote net.Addr, version protocol.VersionNumber) ConnWithVersion {
return &conn{
localAddr: local,
remoteAddr: remote,
version: version,
}
}
var _ net.Conn = &conn{}
func (c *conn) Read([]byte) (int, error) { return 0, nil }
func (c *conn) Write([]byte) (int, error) { return 0, nil }
func (c *conn) Close() error { return nil }
func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *conn) LocalAddr() net.Addr { return c.localAddr }
func (c *conn) SetReadDeadline(time.Time) error { return nil }
func (c *conn) SetWriteDeadline(time.Time) error { return nil }
func (c *conn) SetDeadline(time.Time) error { return nil }
func (c *conn) GetQUICVersion() protocol.VersionNumber { return c.version }
type cryptoSetup struct {
tlsConf *tls.Config
extraConf *qtls.ExtraConfig
conn *qtls.Conn
version protocol.VersionNumber
messageChan chan []byte
isReadingHandshakeMessage chan struct{}
readFirstHandshakeMessage bool
ourParams *wire.TransportParameters
peerParams *wire.TransportParameters
paramsChan <-chan []byte
runner handshakeRunner
alertChan chan uint8
// handshakeDone is closed as soon as the go routine running qtls.Handshake() returns
handshakeDone chan struct{}
// is closed when Close() is called
closeChan chan struct{}
zeroRTTParameters *wire.TransportParameters
clientHelloWritten bool
clientHelloWrittenChan chan *wire.TransportParameters
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
logger utils.Logger
perspective protocol.Perspective
mutex sync.Mutex // protects all members below
handshakeCompleteTime time.Time
readEncLevel protocol.EncryptionLevel
writeEncLevel protocol.EncryptionLevel
zeroRTTOpener LongHeaderOpener // only set for the server
zeroRTTSealer LongHeaderSealer // only set for the client
initialStream io.Writer
initialOpener LongHeaderOpener
initialSealer LongHeaderSealer
handshakeStream io.Writer
handshakeOpener LongHeaderOpener
handshakeSealer LongHeaderSealer
aead *updatableAEAD
has1RTTSealer bool
has1RTTOpener bool
}
var (
_ qtls.RecordLayer = &cryptoSetup{}
_ CryptoSetup = &cryptoSetup{}
)
// NewCryptoSetupClient creates a new crypto setup for the client
func NewCryptoSetupClient(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) (CryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
cs, clientHelloWritten := newCryptoSetup(
initialStream,
handshakeStream,
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveClient,
version,
)
cs.conn = qtls.Client(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs, clientHelloWritten
}
// NewCryptoSetupServer creates a new crypto setup for the server
func NewCryptoSetupServer(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
localAddr net.Addr,
remoteAddr net.Addr,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
version protocol.VersionNumber,
) CryptoSetup {
cs, _ := newCryptoSetup(
initialStream,
handshakeStream,
connID,
tp,
runner,
tlsConf,
enable0RTT,
rttStats,
tracer,
logger,
protocol.PerspectiveServer,
version,
)
cs.conn = qtls.Server(newConn(localAddr, remoteAddr, version), cs.tlsConf, cs.extraConf)
return cs
}
func newCryptoSetup(
initialStream io.Writer,
handshakeStream io.Writer,
connID protocol.ConnectionID,
tp *wire.TransportParameters,
runner handshakeRunner,
tlsConf *tls.Config,
enable0RTT bool,
rttStats *utils.RTTStats,
tracer logging.ConnectionTracer,
logger utils.Logger,
perspective protocol.Perspective,
version protocol.VersionNumber,
) (*cryptoSetup, <-chan *wire.TransportParameters /* ClientHello written. Receive nil for non-0-RTT */) {
initialSealer, initialOpener := NewInitialAEAD(connID, perspective, version)
if tracer != nil {
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
extHandler := newExtensionHandler(tp.Marshal(perspective), perspective, version)
cs := &cryptoSetup{
tlsConf: tlsConf,
initialStream: initialStream,
initialSealer: initialSealer,
initialOpener: initialOpener,
handshakeStream: handshakeStream,
aead: newUpdatableAEAD(rttStats, tracer, logger),
readEncLevel: protocol.EncryptionInitial,
writeEncLevel: protocol.EncryptionInitial,
runner: runner,
ourParams: tp,
paramsChan: extHandler.TransportParameters(),
rttStats: rttStats,
tracer: tracer,
logger: logger,
perspective: perspective,
handshakeDone: make(chan struct{}),
alertChan: make(chan uint8),
clientHelloWrittenChan: make(chan *wire.TransportParameters, 1),
messageChan: make(chan []byte, 100),
isReadingHandshakeMessage: make(chan struct{}),
closeChan: make(chan struct{}),
version: version,
}
var maxEarlyData uint32
if enable0RTT {
maxEarlyData = 0xffffffff
}
cs.extraConf = &qtls.ExtraConfig{
GetExtensions: extHandler.GetExtensions,
ReceivedExtensions: extHandler.ReceivedExtensions,
AlternativeRecordLayer: cs,
EnforceNextProtoSelection: true,
MaxEarlyData: maxEarlyData,
Accept0RTT: cs.accept0RTT,
Rejected0RTT: cs.rejected0RTT,
Enable0RTT: enable0RTT,
GetAppDataForSessionState: cs.marshalDataForSessionState,
SetAppDataFromSessionState: cs.handleDataFromSessionState,
}
return cs, cs.clientHelloWrittenChan
}
func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) {
initialSealer, initialOpener := NewInitialAEAD(id, h.perspective, h.version)
h.initialSealer = initialSealer
h.initialOpener = initialOpener
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient)
h.tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer)
}
}
func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error {
return h.aead.SetLargestAcked(pn)
}
func (h *cryptoSetup) RunHandshake() {
// Handle errors that might occur when HandleData() is called.
handshakeComplete := make(chan struct{})
handshakeErrChan := make(chan error, 1)
go func() {
defer close(h.handshakeDone)
if err := h.conn.Handshake(); err != nil {
handshakeErrChan <- err
return
}
close(handshakeComplete)
}()
select {
case <-handshakeComplete: // return when the handshake is done
h.mutex.Lock()
h.handshakeCompleteTime = time.Now()
h.mutex.Unlock()
h.runner.OnHandshakeComplete()
case <-h.closeChan:
// wait until the Handshake() go routine has returned
<-h.handshakeDone
case alert := <-h.alertChan:
handshakeErr := <-handshakeErrChan
h.onError(alert, handshakeErr.Error())
}
}
func (h *cryptoSetup) onError(alert uint8, message string) {
h.runner.OnError(qerr.NewCryptoError(alert, message))
}
// Close closes the crypto setup.
// It aborts the handshake, if it is still running.
// It must only be called once.
func (h *cryptoSetup) Close() error {
close(h.closeChan)
// wait until qtls.Handshake() actually returned
<-h.handshakeDone
return nil
}
// handleMessage handles a TLS handshake message.
// It is called by the crypto streams when a new message is available.
// It returns if it is done with messages on the same encryption level.
func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ {
msgType := messageType(data[0])
h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel)
if err := h.checkEncryptionLevel(msgType, encLevel); err != nil {
h.onError(alertUnexpectedMessage, err.Error())
return false
}
h.messageChan <- data
if encLevel == protocol.Encryption1RTT {
h.handlePostHandshakeMessage()
return false
}
readLoop:
for {
select {
case data := <-h.paramsChan:
if data == nil {
h.onError(0x6d, "missing quic_transport_parameters extension")
} else {
h.handleTransportParameters(data)
}
case <-h.isReadingHandshakeMessage:
break readLoop
case <-h.handshakeDone:
break readLoop
case <-h.closeChan:
break readLoop
}
}
// We're done with the Initial encryption level after processing a ClientHello / ServerHello,
// but only if a handshake opener and sealer was created.
// Otherwise, a HelloRetryRequest was performed.
// We're done with the Handshake encryption level after processing the Finished message.
return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) ||
msgType == typeFinished
}
func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error {
var expected protocol.EncryptionLevel
switch msgType {
case typeClientHello,
typeServerHello:
expected = protocol.EncryptionInitial
case typeEncryptedExtensions,
typeCertificate,
typeCertificateRequest,
typeCertificateVerify,
typeFinished:
expected = protocol.EncryptionHandshake
case typeNewSessionTicket:
expected = protocol.Encryption1RTT
default:
return fmt.Errorf("unexpected handshake message: %d", msgType)
}
if encLevel != expected {
return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel)
}
return nil
}
func (h *cryptoSetup) handleTransportParameters(data []byte) {
var tp wire.TransportParameters
if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil {
h.runner.OnError(&qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
ErrorMessage: err.Error(),
})
}
h.peerParams = &tp
h.runner.OnReceivedParams(h.peerParams)
}
// must be called after receiving the transport parameters
func (h *cryptoSetup) marshalDataForSessionState() []byte {
buf := &bytes.Buffer{}
quicvarint.Write(buf, clientSessionStateRevision)
quicvarint.Write(buf, uint64(h.rttStats.SmoothedRTT().Microseconds()))
h.peerParams.MarshalForSessionTicket(buf)
return buf.Bytes()
}
func (h *cryptoSetup) handleDataFromSessionState(data []byte) {
tp, err := h.handleDataFromSessionStateImpl(data)
if err != nil {
h.logger.Debugf("Restoring of transport parameters from session ticket failed: %s", err.Error())
return
}
h.zeroRTTParameters = tp
}
func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.TransportParameters, error) {
r := bytes.NewReader(data)
ver, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if ver != clientSessionStateRevision {
return nil, fmt.Errorf("mismatching version. Got %d, expected %d", ver, clientSessionStateRevision)
}
rtt, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
h.rttStats.SetInitialRTT(time.Duration(rtt) * time.Microsecond)
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return nil, err
}
return &tp, nil
}
// only valid for the server
func (h *cryptoSetup) GetSessionTicket() ([]byte, error) {
var appData []byte
// Save transport parameters to the session ticket if we're allowing 0-RTT.
if h.extraConf.MaxEarlyData > 0 {
appData = (&sessionTicket{
Parameters: h.ourParams,
RTT: h.rttStats.SmoothedRTT(),
}).Marshal()
}
return h.conn.GetSessionTicket(appData)
}
// accept0RTT is called for the server when receiving the client's session ticket.
// It decides whether to accept 0-RTT.
func (h *cryptoSetup) accept0RTT(sessionTicketData []byte) bool {
var t sessionTicket
if err := t.Unmarshal(sessionTicketData); err != nil {
h.logger.Debugf("Unmarshalling transport parameters from session ticket failed: %s", err.Error())
return false
}
valid := h.ourParams.ValidFor0RTT(t.Parameters)
if valid {
h.logger.Debugf("Accepting 0-RTT. Restoring RTT from session ticket: %s", t.RTT)
h.rttStats.SetInitialRTT(t.RTT)
} else {
h.logger.Debugf("Transport parameters changed. Rejecting 0-RTT.")
}
return valid
}
// rejected0RTT is called for the client when the server rejects 0-RTT.
func (h *cryptoSetup) rejected0RTT() {
h.logger.Debugf("0-RTT was rejected. Dropping 0-RTT keys.")
h.mutex.Lock()
had0RTTKeys := h.zeroRTTSealer != nil
h.zeroRTTSealer = nil
h.mutex.Unlock()
if had0RTTKeys {
h.runner.DropKeys(protocol.Encryption0RTT)
}
}
func (h *cryptoSetup) handlePostHandshakeMessage() {
// make sure the handshake has already completed
<-h.handshakeDone
done := make(chan struct{})
defer close(done)
// h.alertChan is an unbuffered channel.
// If an error occurs during conn.HandlePostHandshakeMessage,
// it will be sent on this channel.
// Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock.
alertChan := make(chan uint8, 1)
go func() {
<-h.isReadingHandshakeMessage
select {
case alert := <-h.alertChan:
alertChan <- alert
case <-done:
}
}()
if err := h.conn.HandlePostHandshakeMessage(); err != nil {
select {
case <-h.closeChan:
case alert := <-alertChan:
h.onError(alert, err.Error())
}
}
}
// ReadHandshakeMessage is called by TLS.
// It blocks until a new handshake message is available.
func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) {
if !h.readFirstHandshakeMessage {
h.readFirstHandshakeMessage = true
} else {
select {
case h.isReadingHandshakeMessage <- struct{}{}:
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
select {
case msg := <-h.messageChan:
return msg, nil
case <-h.closeChan:
return nil, errors.New("error while handling the handshake message")
}
}
func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveClient {
panic("Received 0-RTT read key for the client")
}
h.zeroRTTOpener = newLongHeaderOpener(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite())
}
return
case qtls.EncryptionHandshake:
h.readEncLevel = protocol.EncryptionHandshake
h.handshakeOpener = newHandshakeOpener(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
h.dropInitialKeys,
h.perspective,
)
h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID))
case qtls.EncryptionApplication:
h.readEncLevel = protocol.Encryption1RTT
h.aead.SetReadKey(suite, trafficSecret)
h.has1RTTOpener = true
h.logger.Debugf("Installed 1-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID))
default:
panic("unexpected read encryption level")
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite())
}
}
func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
h.mutex.Lock()
switch encLevel {
case qtls.Encryption0RTT:
if h.perspective == protocol.PerspectiveServer {
panic("Received 0-RTT write key for the server")
}
h.zeroRTTSealer = newLongHeaderSealer(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
)
h.mutex.Unlock()
h.logger.Debugf("Installed 0-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective)
}
return
case qtls.EncryptionHandshake:
h.writeEncLevel = protocol.EncryptionHandshake
h.handshakeSealer = newHandshakeSealer(
createAEAD(suite, trafficSecret),
newHeaderProtector(suite, trafficSecret, true),
h.dropInitialKeys,
h.perspective,
)
h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID))
case qtls.EncryptionApplication:
h.writeEncLevel = protocol.Encryption1RTT
h.aead.SetWriteKey(suite, trafficSecret)
h.has1RTTSealer = true
h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID))
if h.zeroRTTSealer != nil {
h.zeroRTTSealer = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
default:
panic("unexpected write encryption level")
}
h.mutex.Unlock()
if h.tracer != nil {
h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective)
}
}
// WriteRecord is called when TLS writes data
func (h *cryptoSetup) WriteRecord(p []byte) (int, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
//nolint:exhaustive // LS records can only be written for Initial and Handshake.
switch h.writeEncLevel {
case protocol.EncryptionInitial:
// assume that the first WriteRecord call contains the ClientHello
n, err := h.initialStream.Write(p)
if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient {
h.clientHelloWritten = true
if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil {
h.logger.Debugf("Doing 0-RTT.")
h.clientHelloWrittenChan <- h.zeroRTTParameters
} else {
h.logger.Debugf("Not doing 0-RTT.")
h.clientHelloWrittenChan <- nil
}
}
return n, err
case protocol.EncryptionHandshake:
return h.handshakeStream.Write(p)
default:
panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel))
}
}
func (h *cryptoSetup) SendAlert(alert uint8) {
select {
case h.alertChan <- alert:
case <-h.closeChan:
// no need to send an alert when we've already closed
}
}
// used a callback in the handshakeSealer and handshakeOpener
func (h *cryptoSetup) dropInitialKeys() {
h.mutex.Lock()
h.initialOpener = nil
h.initialSealer = nil
h.mutex.Unlock()
h.runner.DropKeys(protocol.EncryptionInitial)
h.logger.Debugf("Dropping Initial keys.")
}
func (h *cryptoSetup) SetHandshakeConfirmed() {
h.aead.SetHandshakeConfirmed()
// drop Handshake keys
var dropped bool
h.mutex.Lock()
if h.handshakeOpener != nil {
h.handshakeOpener = nil
h.handshakeSealer = nil
dropped = true
}
h.mutex.Unlock()
if dropped {
h.runner.DropKeys(protocol.EncryptionHandshake)
h.logger.Debugf("Dropping Handshake keys.")
}
}
func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return h.initialSealer, nil
}
func (h *cryptoSetup) Get0RTTSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTSealer == nil {
return nil, ErrKeysDropped
}
return h.zeroRTTSealer, nil
}
func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeSealer == nil {
if h.initialSealer == nil {
return nil, ErrKeysDropped
}
return nil, ErrKeysNotYetAvailable
}
return h.handshakeSealer, nil
}
func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if !h.has1RTTSealer {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) GetInitialOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.initialOpener == nil {
return nil, ErrKeysDropped
}
return h.initialOpener, nil
}
func (h *cryptoSetup) Get0RTTOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.zeroRTTOpener, nil
}
func (h *cryptoSetup) GetHandshakeOpener() (LongHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.handshakeOpener == nil {
if h.initialOpener != nil {
return nil, ErrKeysNotYetAvailable
}
// if the initial opener is also not available, the keys were already dropped
return nil, ErrKeysDropped
}
return h.handshakeOpener, nil
}
func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) {
h.mutex.Lock()
defer h.mutex.Unlock()
if h.zeroRTTOpener != nil && time.Since(h.handshakeCompleteTime) > 3*h.rttStats.PTO(true) {
h.zeroRTTOpener = nil
h.logger.Debugf("Dropping 0-RTT keys.")
if h.tracer != nil {
h.tracer.DroppedEncryptionLevel(protocol.Encryption0RTT)
}
}
if !h.has1RTTOpener {
return nil, ErrKeysNotYetAvailable
}
return h.aead, nil
}
func (h *cryptoSetup) ConnectionState() ConnectionState {
return qtls.GetConnectionState(h.conn)
}

View File

@@ -0,0 +1,127 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"golang.org/x/crypto/chacha20"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
type headerProtector interface {
EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte)
}
func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
return newAESHeaderProtector(suite, trafficSecret, isLongHeader)
case tls.TLS_CHACHA20_POLY1305_SHA256:
return newChaChaHeaderProtector(suite, trafficSecret, isLongHeader)
default:
panic(fmt.Sprintf("Invalid cipher suite id: %d", suite.ID))
}
}
type aesHeaderProtector struct {
mask []byte
block cipher.Block
isLongHeader bool
}
var _ headerProtector = &aesHeaderProtector{}
func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
block, err := aes.NewCipher(hpKey)
if err != nil {
panic(fmt.Sprintf("error creating new AES cipher: %s", err))
}
return &aesHeaderProtector{
block: block,
mask: make([]byte, block.BlockSize()),
isLongHeader: isLongHeader,
}
}
func (p *aesHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *aesHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != len(p.mask) {
panic("invalid sample size")
}
p.block.Encrypt(p.mask, sample)
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}
type chachaHeaderProtector struct {
mask [5]byte
key [32]byte
isLongHeader bool
}
var _ headerProtector = &chachaHeaderProtector{}
func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool) headerProtector {
hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, "quic hp", suite.KeyLen)
p := &chachaHeaderProtector{
isLongHeader: isLongHeader,
}
copy(p.key[:], hpKey)
return p
}
func (p *chachaHeaderProtector) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
p.apply(sample, firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) apply(sample []byte, firstByte *byte, hdrBytes []byte) {
if len(sample) != 16 {
panic("invalid sample size")
}
for i := 0; i < 5; i++ {
p.mask[i] = 0
}
cipher, err := chacha20.NewUnauthenticatedCipher(p.key[:], sample[4:])
if err != nil {
panic(err)
}
cipher.SetCounter(binary.LittleEndian.Uint32(sample[:4]))
cipher.XORKeyStream(p.mask[:], p.mask[:])
p.applyMask(firstByte, hdrBytes)
}
func (p *chachaHeaderProtector) applyMask(firstByte *byte, hdrBytes []byte) {
if p.isLongHeader {
*firstByte ^= p.mask[0] & 0xf
} else {
*firstByte ^= p.mask[0] & 0x1f
}
for i := range hdrBytes {
hdrBytes[i] ^= p.mask[i+1]
}
}

View File

@@ -0,0 +1,29 @@
package handshake
import (
"crypto"
"encoding/binary"
"golang.org/x/crypto/hkdf"
)
// hkdfExpandLabel HKDF expands a label.
// Since this implementation avoids using a cryptobyte.Builder, it is about 15% faster than the
// hkdfExpandLabel in the standard library.
func hkdfExpandLabel(hash crypto.Hash, secret, context []byte, label string, length int) []byte {
b := make([]byte, 3, 3+6+len(label)+1+len(context))
binary.BigEndian.PutUint16(b, uint16(length))
b[2] = uint8(6 + len(label))
b = append(b, []byte("tls13 ")...)
b = append(b, []byte(label)...)
b = b[:3+6+len(label)+1]
b[3+6+len(label)] = uint8(len(context))
b = append(b, context...)
out := make([]byte, length)
n, err := hkdf.Expand(hash.New, secret, b).Read(out)
if err != nil || n != length {
panic("quic: HKDF-Expand-Label invocation failed unexpectedly")
}
return out
}

View File

@@ -0,0 +1,64 @@
package handshake
import (
"crypto"
"crypto/tls"
"golang.org/x/crypto/hkdf"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
var (
quicSaltOld = []byte{0xaf, 0xbf, 0xec, 0x28, 0x99, 0x93, 0xd2, 0x4c, 0x9e, 0x97, 0x86, 0xf1, 0x9c, 0x61, 0x11, 0xe0, 0x43, 0x90, 0xa8, 0x99}
quicSaltDraft34 = []byte{0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3, 0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad, 0xcc, 0xbb, 0x7f, 0x0a}
)
func getSalt(v protocol.VersionNumber) []byte {
if v == protocol.VersionDraft34 || v == protocol.Version1 {
return quicSaltDraft34
}
return quicSaltOld
}
var initialSuite = &qtls.CipherSuiteTLS13{
ID: tls.TLS_AES_128_GCM_SHA256,
KeyLen: 16,
AEAD: qtls.AEADAESGCMTLS13,
Hash: crypto.SHA256,
}
// NewInitialAEAD creates a new AEAD for Initial encryption / decryption.
func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) {
clientSecret, serverSecret := computeSecrets(connID, v)
var mySecret, otherSecret []byte
if pers == protocol.PerspectiveClient {
mySecret = clientSecret
otherSecret = serverSecret
} else {
mySecret = serverSecret
otherSecret = clientSecret
}
myKey, myIV := computeInitialKeyAndIV(mySecret)
otherKey, otherIV := computeInitialKeyAndIV(otherSecret)
encrypter := qtls.AEADAESGCMTLS13(myKey, myIV)
decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV)
return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true)),
newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true))
}
func computeSecrets(connID protocol.ConnectionID, v protocol.VersionNumber) (clientSecret, serverSecret []byte) {
initialSecret := hkdf.Extract(crypto.SHA256.New, connID, getSalt(v))
clientSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
serverSecret = hkdfExpandLabel(crypto.SHA256, initialSecret, []byte{}, "server in", crypto.SHA256.Size())
return
}
func computeInitialKeyAndIV(secret []byte) (key, iv []byte) {
key = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic key", 16)
iv = hkdfExpandLabel(crypto.SHA256, secret, []byte{}, "quic iv", 12)
return
}

View File

@@ -0,0 +1,102 @@
package handshake
import (
"errors"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/wire"
)
var (
// ErrKeysNotYetAvailable is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding opener has not yet been initialized
// This can happen when packets arrive out of order.
ErrKeysNotYetAvailable = errors.New("CryptoSetup: keys at this encryption level not yet available")
// ErrKeysDropped is returned when an opener or a sealer is requested for an encryption level,
// but the corresponding keys have already been dropped.
ErrKeysDropped = errors.New("CryptoSetup: keys were already dropped")
// ErrDecryptionFailed is returned when the AEAD fails to open the packet.
ErrDecryptionFailed = errors.New("decryption failed")
)
// ConnectionState contains information about the state of the connection.
type ConnectionState = qtls.ConnectionState
type headerDecryptor interface {
DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}
// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}
// LongHeaderSealer seals a long header packet
type LongHeaderSealer interface {
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
Overhead() int
}
// ShortHeaderSealer seals a short header packet
type ShortHeaderSealer interface {
LongHeaderSealer
KeyPhase() protocol.KeyPhaseBit
}
// A tlsExtensionHandler sends and received the QUIC TLS extension.
type tlsExtensionHandler interface {
GetExtensions(msgType uint8) []qtls.Extension
ReceivedExtensions(msgType uint8, exts []qtls.Extension)
TransportParameters() <-chan []byte
}
type handshakeRunner interface {
OnReceivedParams(*wire.TransportParameters)
OnHandshakeComplete()
OnError(error)
DropKeys(protocol.EncryptionLevel)
}
// CryptoSetup handles the handshake and protecting / unprotecting packets
type CryptoSetup interface {
RunHandshake()
io.Closer
ChangeConnectionID(protocol.ConnectionID)
GetSessionTicket() ([]byte, error)
HandleMessage([]byte, protocol.EncryptionLevel) bool
SetLargest1RTTAcked(protocol.PacketNumber) error
SetHandshakeConfirmed()
ConnectionState() ConnectionState
GetInitialOpener() (LongHeaderOpener, error)
GetHandshakeOpener() (LongHeaderOpener, error)
Get0RTTOpener() (LongHeaderOpener, error)
Get1RTTOpener() (ShortHeaderOpener, error)
GetInitialSealer() (LongHeaderSealer, error)
GetHandshakeSealer() (LongHeaderSealer, error)
Get0RTTSealer() (LongHeaderSealer, error)
Get1RTTSealer() (ShortHeaderSealer, error)
}
// ConnWithVersion is the connection used in the ClientHelloInfo.
// It can be used to determine the QUIC version in use.
type ConnWithVersion interface {
net.Conn
GetQUICVersion() protocol.VersionNumber
}

View File

@@ -0,0 +1,3 @@
package handshake
//go:generate sh -c "../../mockgen_private.sh handshake mock_handshake_runner_test.go github.com/lucas-clemente/quic-go/internal/handshake handshakeRunner"

View File

@@ -0,0 +1,62 @@
package handshake
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
oldRetryAEAD cipher.AEAD // used for QUIC draft versions up to 34
retryAEAD cipher.AEAD // used for QUIC draft-34
)
func init() {
oldRetryAEAD = initAEAD([16]byte{0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1})
retryAEAD = initAEAD([16]byte{0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e})
}
func initAEAD(key [16]byte) cipher.AEAD {
aes, err := aes.NewCipher(key[:])
if err != nil {
panic(err)
}
aead, err := cipher.NewGCM(aes)
if err != nil {
panic(err)
}
return aead
}
var (
retryBuf bytes.Buffer
retryMutex sync.Mutex
oldRetryNonce = [12]byte{0xe5, 0x49, 0x30, 0xf9, 0x7f, 0x21, 0x36, 0xf0, 0x53, 0x0a, 0x8c, 0x1c}
retryNonce = [12]byte{0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}
)
// GetRetryIntegrityTag calculates the integrity tag on a Retry packet
func GetRetryIntegrityTag(retry []byte, origDestConnID protocol.ConnectionID, version protocol.VersionNumber) *[16]byte {
retryMutex.Lock()
retryBuf.WriteByte(uint8(origDestConnID.Len()))
retryBuf.Write(origDestConnID.Bytes())
retryBuf.Write(retry)
var tag [16]byte
var sealed []byte
if version != protocol.VersionDraft34 && version != protocol.Version1 {
sealed = oldRetryAEAD.Seal(tag[:0], oldRetryNonce[:], nil, retryBuf.Bytes())
} else {
sealed = retryAEAD.Seal(tag[:0], retryNonce[:], nil, retryBuf.Bytes())
}
if len(sealed) != 16 {
panic(fmt.Sprintf("unexpected Retry integrity tag length: %d", len(sealed)))
}
retryBuf.Reset()
retryMutex.Unlock()
return &tag
}

View File

@@ -0,0 +1,48 @@
package handshake
import (
"bytes"
"errors"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/quicvarint"
)
const sessionTicketRevision = 2
type sessionTicket struct {
Parameters *wire.TransportParameters
RTT time.Duration // to be encoded in mus
}
func (t *sessionTicket) Marshal() []byte {
b := &bytes.Buffer{}
quicvarint.Write(b, sessionTicketRevision)
quicvarint.Write(b, uint64(t.RTT.Microseconds()))
t.Parameters.MarshalForSessionTicket(b)
return b.Bytes()
}
func (t *sessionTicket) Unmarshal(b []byte) error {
r := bytes.NewReader(b)
rev, err := quicvarint.Read(r)
if err != nil {
return errors.New("failed to read session ticket revision")
}
if rev != sessionTicketRevision {
return fmt.Errorf("unknown session ticket revision: %d", rev)
}
rtt, err := quicvarint.Read(r)
if err != nil {
return errors.New("failed to read RTT")
}
var tp wire.TransportParameters
if err := tp.UnmarshalFromSessionTicket(r); err != nil {
return fmt.Errorf("unmarshaling transport parameters from session ticket failed: %s", err.Error())
}
t.Parameters = &tp
t.RTT = time.Duration(rtt) * time.Microsecond
return nil
}

View File

@@ -0,0 +1,68 @@
package handshake
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
const (
quicTLSExtensionTypeOldDrafts = 0xffa5
quicTLSExtensionType = 0x39
)
type extensionHandler struct {
ourParams []byte
paramsChan chan []byte
extensionType uint16
perspective protocol.Perspective
}
var _ tlsExtensionHandler = &extensionHandler{}
// newExtensionHandler creates a new extension handler
func newExtensionHandler(params []byte, pers protocol.Perspective, v protocol.VersionNumber) tlsExtensionHandler {
et := uint16(quicTLSExtensionType)
if v != protocol.VersionDraft34 && v != protocol.Version1 {
et = quicTLSExtensionTypeOldDrafts
}
return &extensionHandler{
ourParams: params,
paramsChan: make(chan []byte),
perspective: pers,
extensionType: et,
}
}
func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) {
return nil
}
return []qtls.Extension{{
Type: h.extensionType,
Data: h.ourParams,
}}
}
func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) {
if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) ||
(h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) {
return
}
var data []byte
for _, ext := range exts {
if ext.Type == h.extensionType {
data = ext.Data
break
}
}
h.paramsChan <- data
}
func (h *extensionHandler) TransportParameters() <-chan []byte {
return h.paramsChan
}

View File

@@ -0,0 +1,134 @@
package handshake
import (
"encoding/asn1"
"fmt"
"io"
"net"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
tokenPrefixIP byte = iota
tokenPrefixString
)
// A Token is derived from the client address and can be used to verify the ownership of this address.
type Token struct {
IsRetryToken bool
RemoteAddr string
SentTime time.Time
// only set for retry tokens
OriginalDestConnectionID protocol.ConnectionID
RetrySrcConnectionID protocol.ConnectionID
}
// token is the struct that is used for ASN1 serialization and deserialization
type token struct {
IsRetryToken bool
RemoteAddr []byte
Timestamp int64
OriginalDestConnectionID []byte
RetrySrcConnectionID []byte
}
// A TokenGenerator generates tokens
type TokenGenerator struct {
tokenProtector tokenProtector
}
// NewTokenGenerator initializes a new TookenGenerator
func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) {
tokenProtector, err := newTokenProtector(rand)
if err != nil {
return nil, err
}
return &TokenGenerator{
tokenProtector: tokenProtector,
}, nil
}
// NewRetryToken generates a new token for a Retry for a given source address
func (g *TokenGenerator) NewRetryToken(
raddr net.Addr,
origDestConnID protocol.ConnectionID,
retrySrcConnID protocol.ConnectionID,
) ([]byte, error) {
data, err := asn1.Marshal(token{
IsRetryToken: true,
RemoteAddr: encodeRemoteAddr(raddr),
OriginalDestConnectionID: origDestConnID,
RetrySrcConnectionID: retrySrcConnID,
Timestamp: time.Now().UnixNano(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// NewToken generates a new token to be sent in a NEW_TOKEN frame
func (g *TokenGenerator) NewToken(raddr net.Addr) ([]byte, error) {
data, err := asn1.Marshal(token{
RemoteAddr: encodeRemoteAddr(raddr),
Timestamp: time.Now().UnixNano(),
})
if err != nil {
return nil, err
}
return g.tokenProtector.NewToken(data)
}
// DecodeToken decodes a token
func (g *TokenGenerator) DecodeToken(encrypted []byte) (*Token, error) {
// if the client didn't send any token, DecodeToken will be called with a nil-slice
if len(encrypted) == 0 {
return nil, nil
}
data, err := g.tokenProtector.DecodeToken(encrypted)
if err != nil {
return nil, err
}
t := &token{}
rest, err := asn1.Unmarshal(data, t)
if err != nil {
return nil, err
}
if len(rest) != 0 {
return nil, fmt.Errorf("rest when unpacking token: %d", len(rest))
}
token := &Token{
IsRetryToken: t.IsRetryToken,
RemoteAddr: decodeRemoteAddr(t.RemoteAddr),
SentTime: time.Unix(0, t.Timestamp),
}
if t.IsRetryToken {
token.OriginalDestConnectionID = protocol.ConnectionID(t.OriginalDestConnectionID)
token.RetrySrcConnectionID = protocol.ConnectionID(t.RetrySrcConnectionID)
}
return token, nil
}
// encodeRemoteAddr encodes a remote address such that it can be saved in the token
func encodeRemoteAddr(remoteAddr net.Addr) []byte {
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
return append([]byte{tokenPrefixIP}, udpAddr.IP...)
}
return append([]byte{tokenPrefixString}, []byte(remoteAddr.String())...)
}
// decodeRemoteAddr decodes the remote address saved in the token
func decodeRemoteAddr(data []byte) string {
// data will never be empty for a token that we generated.
// Check it to be on the safe side
if len(data) == 0 {
return ""
}
if data[0] == tokenPrefixIP {
return net.IP(data[1:]).String()
}
return string(data[1:])
}

View File

@@ -0,0 +1,89 @@
package handshake
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
// TokenProtector is used to create and verify a token
type tokenProtector interface {
// NewToken creates a new token
NewToken([]byte) ([]byte, error)
// DecodeToken decodes a token
DecodeToken([]byte) ([]byte, error)
}
const (
tokenSecretSize = 32
tokenNonceSize = 32
)
// tokenProtector is used to create and verify a token
type tokenProtectorImpl struct {
rand io.Reader
secret []byte
}
// newTokenProtector creates a source for source address tokens
func newTokenProtector(rand io.Reader) (tokenProtector, error) {
secret := make([]byte, tokenSecretSize)
if _, err := rand.Read(secret); err != nil {
return nil, err
}
return &tokenProtectorImpl{
rand: rand,
secret: secret,
}, nil
}
// NewToken encodes data into a new token.
func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) {
nonce := make([]byte, tokenNonceSize)
if _, err := s.rand.Read(nonce); err != nil {
return nil, err
}
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return append(nonce, aead.Seal(nil, aeadNonce, data, nil)...), nil
}
// DecodeToken decodes a token.
func (s *tokenProtectorImpl) DecodeToken(p []byte) ([]byte, error) {
if len(p) < tokenNonceSize {
return nil, fmt.Errorf("token too short: %d", len(p))
}
nonce := p[:tokenNonceSize]
aead, aeadNonce, err := s.createAEAD(nonce)
if err != nil {
return nil, err
}
return aead.Open(nil, aeadNonce, p[tokenNonceSize:], nil)
}
func (s *tokenProtectorImpl) createAEAD(nonce []byte) (cipher.AEAD, []byte, error) {
h := hkdf.New(sha256.New, s.secret, nonce, []byte("quic-go token source"))
key := make([]byte, 32) // use a 32 byte key, in order to select AES-256
if _, err := io.ReadFull(h, key); err != nil {
return nil, nil, err
}
aeadNonce := make([]byte, 12)
if _, err := io.ReadFull(h, aeadNonce); err != nil {
return nil, nil, err
}
c, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
aead, err := cipher.NewGCM(c)
if err != nil {
return nil, nil, err
}
return aead, aeadNonce, nil
}

View File

@@ -0,0 +1,321 @@
package handshake
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"encoding/binary"
"fmt"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/logging"
)
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
// It's a package-level variable to allow modifying it for testing purposes.
var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval
type updatableAEAD struct {
suite *qtls.CipherSuiteTLS13
keyPhase protocol.KeyPhase
largestAcked protocol.PacketNumber
firstPacketNumber protocol.PacketNumber
handshakeConfirmed bool
keyUpdateInterval uint64
invalidPacketLimit uint64
invalidPacketCount uint64
// Time when the keys should be dropped. Keys are dropped on the next call to Open().
prevRcvAEADExpiry time.Time
prevRcvAEAD cipher.AEAD
firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
sendAEAD cipher.AEAD
// caches cipher.AEAD.Overhead(). This speeds up calls to Overhead().
aeadOverhead int
nextRcvAEAD cipher.AEAD
nextSendAEAD cipher.AEAD
nextRcvTrafficSecret []byte
nextSendTrafficSecret []byte
headerDecrypter headerProtector
headerEncrypter headerProtector
rttStats *utils.RTTStats
tracer logging.ConnectionTracer
logger utils.Logger
// use a single slice to avoid allocations
nonceBuf []byte
}
var (
_ ShortHeaderOpener = &updatableAEAD{}
_ ShortHeaderSealer = &updatableAEAD{}
)
func newUpdatableAEAD(rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger) *updatableAEAD {
return &updatableAEAD{
firstPacketNumber: protocol.InvalidPacketNumber,
largestAcked: protocol.InvalidPacketNumber,
firstRcvdWithCurrentKey: protocol.InvalidPacketNumber,
firstSentWithCurrentKey: protocol.InvalidPacketNumber,
keyUpdateInterval: KeyUpdateInterval,
rttStats: rttStats,
tracer: tracer,
logger: logger,
}
}
func (a *updatableAEAD) rollKeys() {
if a.prevRcvAEAD != nil {
a.logger.Debugf("Dropping key phase %d ahead of scheduled time. Drop time was: %s", a.keyPhase-1, a.prevRcvAEADExpiry)
if a.tracer != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
a.prevRcvAEADExpiry = time.Time{}
}
a.keyPhase++
a.firstRcvdWithCurrentKey = protocol.InvalidPacketNumber
a.firstSentWithCurrentKey = protocol.InvalidPacketNumber
a.numRcvdWithCurrentKey = 0
a.numSentWithCurrentKey = 0
a.prevRcvAEAD = a.rcvAEAD
a.rcvAEAD = a.nextRcvAEAD
a.sendAEAD = a.nextSendAEAD
a.nextRcvTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextRcvTrafficSecret)
a.nextSendTrafficSecret = a.getNextTrafficSecret(a.suite.Hash, a.nextSendTrafficSecret)
a.nextRcvAEAD = createAEAD(a.suite, a.nextRcvTrafficSecret)
a.nextSendAEAD = createAEAD(a.suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) startKeyDropTimer(now time.Time) {
d := 3 * a.rttStats.PTO(true)
a.logger.Debugf("Starting key drop timer to drop key phase %d (in %s)", a.keyPhase-1, d)
a.prevRcvAEADExpiry = now.Add(d)
}
func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte {
return hkdfExpandLabel(hash, ts, []byte{}, "quic ku", hash.Size())
}
// For the client, this function is called before SetWriteKey.
// For the server, this function is called after SetWriteKey.
func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
a.rcvAEAD = createAEAD(suite, trafficSecret)
a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false)
if a.suite == nil {
a.setAEADParameters(a.rcvAEAD, suite)
}
a.nextRcvTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextRcvAEAD = createAEAD(suite, a.nextRcvTrafficSecret)
}
// For the client, this function is called after SetReadKey.
// For the server, this function is called before SetWriteKey.
func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) {
a.sendAEAD = createAEAD(suite, trafficSecret)
a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false)
if a.suite == nil {
a.setAEADParameters(a.sendAEAD, suite)
}
a.nextSendTrafficSecret = a.getNextTrafficSecret(suite.Hash, trafficSecret)
a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret)
}
func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) {
a.nonceBuf = make([]byte, aead.NonceSize())
a.aeadOverhead = aead.Overhead()
a.suite = suite
switch suite.ID {
case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384:
a.invalidPacketLimit = protocol.InvalidPacketLimitAES
case tls.TLS_CHACHA20_POLY1305_SHA256:
a.invalidPacketLimit = protocol.InvalidPacketLimitChaCha
default:
panic(fmt.Sprintf("unknown cipher suite %d", suite.ID))
}
}
func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}
func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed {
a.invalidPacketCount++
if a.invalidPacketCount >= a.invalidPacketLimit {
return nil, &qerr.TransportError{ErrorCode: qerr.AEADLimitReached}
}
}
if err == nil {
a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
}
return dec, err
}
func (a *updatableAEAD) open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
if a.prevRcvAEAD != nil && !a.prevRcvAEADExpiry.IsZero() && rcvTime.After(a.prevRcvAEADExpiry) {
a.prevRcvAEAD = nil
a.logger.Debugf("Dropping key phase %d", a.keyPhase-1)
a.prevRcvAEADExpiry = time.Time{}
if a.tracer != nil {
a.tracer.DroppedKey(a.keyPhase - 1)
}
}
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
if kp != a.keyPhase.Bit() {
if a.keyPhase > 0 && a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber || pn < a.firstRcvdWithCurrentKey {
if a.prevRcvAEAD == nil {
return nil, ErrKeysDropped
}
// we updated the key, but the peer hasn't updated yet
dec, err := a.prevRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
err = ErrDecryptionFailed
}
return dec, err
}
// try opening the packet with the next key phase
dec, err := a.nextRcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return nil, ErrDecryptionFailed
}
// Opening succeeded. Check if the peer was allowed to update.
if a.keyPhase > 0 && a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
return nil, &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: "keys updated too quickly",
}
}
a.rollKeys()
a.logger.Debugf("Peer updated keys to %d", a.keyPhase)
// The peer initiated this key update. It's safe to drop the keys for the previous generation now.
// Start a timer to drop the previous key generation.
a.startKeyDropTimer(rcvTime)
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, true)
}
a.firstRcvdWithCurrentKey = pn
return dec, err
}
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := a.rcvAEAD.Open(dst, a.nonceBuf, src, ad)
if err != nil {
return dec, ErrDecryptionFailed
}
a.numRcvdWithCurrentKey++
if a.firstRcvdWithCurrentKey == protocol.InvalidPacketNumber {
// We initiated the key updated, and now we received the first packet protected with the new key phase.
// Therefore, we are certain that the peer rolled its keys as well. Start a timer to drop the old keys.
if a.keyPhase > 0 {
a.logger.Debugf("Peer confirmed key update to phase %d", a.keyPhase)
a.startKeyDropTimer(rcvTime)
}
a.firstRcvdWithCurrentKey = pn
}
return dec, err
}
func (a *updatableAEAD) Seal(dst, src []byte, pn protocol.PacketNumber, ad []byte) []byte {
if a.firstSentWithCurrentKey == protocol.InvalidPacketNumber {
a.firstSentWithCurrentKey = pn
}
if a.firstPacketNumber == protocol.InvalidPacketNumber {
a.firstPacketNumber = pn
}
a.numSentWithCurrentKey++
binary.BigEndian.PutUint64(a.nonceBuf[len(a.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
return a.sendAEAD.Seal(dst, a.nonceBuf, src, ad)
}
func (a *updatableAEAD) SetLargestAcked(pn protocol.PacketNumber) error {
if a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
pn >= a.firstSentWithCurrentKey && a.numRcvdWithCurrentKey == 0 {
return &qerr.TransportError{
ErrorCode: qerr.KeyUpdateError,
ErrorMessage: fmt.Sprintf("received ACK for key phase %d, but peer didn't update keys", a.keyPhase),
}
}
a.largestAcked = pn
return nil
}
func (a *updatableAEAD) SetHandshakeConfirmed() {
a.handshakeConfirmed = true
}
func (a *updatableAEAD) updateAllowed() bool {
if !a.handshakeConfirmed {
return false
}
// the first key update is allowed as soon as the handshake is confirmed
return a.keyPhase == 0 ||
// subsequent key updates as soon as a packet sent with that key phase has been acknowledged
(a.firstSentWithCurrentKey != protocol.InvalidPacketNumber &&
a.largestAcked != protocol.InvalidPacketNumber &&
a.largestAcked >= a.firstSentWithCurrentKey)
}
func (a *updatableAEAD) shouldInitiateKeyUpdate() bool {
if !a.updateAllowed() {
return false
}
if a.numRcvdWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Received %d packets with current key phase. Initiating key update to the next key phase: %d", a.numRcvdWithCurrentKey, a.keyPhase+1)
return true
}
if a.numSentWithCurrentKey >= a.keyUpdateInterval {
a.logger.Debugf("Sent %d packets with current key phase. Initiating key update to the next key phase: %d", a.numSentWithCurrentKey, a.keyPhase+1)
return true
}
return false
}
func (a *updatableAEAD) KeyPhase() protocol.KeyPhaseBit {
if a.shouldInitiateKeyUpdate() {
a.rollKeys()
a.logger.Debugf("Initiating key update to key phase %d", a.keyPhase)
if a.tracer != nil {
a.tracer.UpdatedKey(a.keyPhase, false)
}
}
return a.keyPhase.Bit()
}
func (a *updatableAEAD) Overhead() int {
return a.aeadOverhead
}
func (a *updatableAEAD) EncryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerEncrypter.EncryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) DecryptHeader(sample []byte, firstByte *byte, hdrBytes []byte) {
a.headerDecrypter.DecryptHeader(sample, firstByte, hdrBytes)
}
func (a *updatableAEAD) FirstPacketNumber() protocol.PacketNumber {
return a.firstPacketNumber
}

View File

@@ -0,0 +1,33 @@
package logutils
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/wire"
"github.com/lucas-clemente/quic-go/logging"
)
// ConvertFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// Furthermore, it removes the data slices from CRYPTO and STREAM frames.
func ConvertFrame(frame wire.Frame) logging.Frame {
switch f := frame.(type) {
case *wire.CryptoFrame:
return &logging.CryptoFrame{
Offset: f.Offset,
Length: protocol.ByteCount(len(f.Data)),
}
case *wire.StreamFrame:
return &logging.StreamFrame{
StreamID: f.StreamID,
Offset: f.Offset,
Length: f.DataLen(),
Fin: f.Fin,
}
case *wire.DatagramFrame:
return &logging.DatagramFrame{
Length: logging.ByteCount(len(f.Data)),
}
default:
return logging.Frame(frame)
}
}

View File

@@ -0,0 +1,69 @@
package protocol
import (
"bytes"
"crypto/rand"
"fmt"
"io"
)
// A ConnectionID in QUIC
type ConnectionID []byte
const maxConnectionIDLen = 20
// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return ConnectionID(b), nil
}
// GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 20 bytes.
func GenerateConnectionIDForInitial() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
}
// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
if len == 0 {
return nil, nil
}
c := make(ConnectionID, len)
_, err := io.ReadFull(r, c)
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return c, err
}
// Equal says if two connection IDs are equal
func (c ConnectionID) Equal(other ConnectionID) bool {
return bytes.Equal(c, other)
}
// Len returns the length of the connection ID in bytes
func (c ConnectionID) Len() int {
return len(c)
}
// Bytes returns the byte representation
func (c ConnectionID) Bytes() []byte {
return []byte(c)
}
func (c ConnectionID) String() string {
if c.Len() == 0 {
return "(empty)"
}
return fmt.Sprintf("%x", c.Bytes())
}

View File

@@ -0,0 +1,30 @@
package protocol
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel uint8
const (
// EncryptionInitial is the Initial encryption level
EncryptionInitial EncryptionLevel = 1 + iota
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT
// Encryption1RTT is the 1-RTT encryption level
Encryption1RTT
)
func (e EncryptionLevel) String() string {
switch e {
case EncryptionInitial:
return "Initial"
case EncryptionHandshake:
return "Handshake"
case Encryption0RTT:
return "0-RTT"
case Encryption1RTT:
return "1-RTT"
}
return "unknown"
}

View File

@@ -0,0 +1,36 @@
package protocol
// KeyPhase is the key phase
type KeyPhase uint64
// Bit determines the key phase bit
func (p KeyPhase) Bit() KeyPhaseBit {
if p%2 == 0 {
return KeyPhaseZero
}
return KeyPhaseOne
}
// KeyPhaseBit is the key phase bit
type KeyPhaseBit uint8
const (
// KeyPhaseUndefined is an undefined key phase
KeyPhaseUndefined KeyPhaseBit = iota
// KeyPhaseZero is key phase 0
KeyPhaseZero
// KeyPhaseOne is key phase 1
KeyPhaseOne
)
func (p KeyPhaseBit) String() string {
//nolint:exhaustive
switch p {
case KeyPhaseZero:
return "0"
case KeyPhaseOne:
return "1"
default:
return "undefined"
}
}

View File

@@ -0,0 +1,79 @@
package protocol
// A PacketNumber in QUIC
type PacketNumber int64
// InvalidPacketNumber is a packet number that is never sent.
// In QUIC, 0 is a valid packet number.
const InvalidPacketNumber PacketNumber = -1
// PacketNumberLen is the length of the packet number in bytes
type PacketNumberLen uint8
const (
// PacketNumberLen1 is a packet number length of 1 byte
PacketNumberLen1 PacketNumberLen = 1
// PacketNumberLen2 is a packet number length of 2 bytes
PacketNumberLen2 PacketNumberLen = 2
// PacketNumberLen3 is a packet number length of 3 bytes
PacketNumberLen3 PacketNumberLen = 3
// PacketNumberLen4 is a packet number length of 4 bytes
PacketNumberLen4 PacketNumberLen = 4
)
// DecodePacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
func DecodePacketNumber(
packetNumberLength PacketNumberLen,
lastPacketNumber PacketNumber,
wirePacketNumber PacketNumber,
) PacketNumber {
var epochDelta PacketNumber
switch packetNumberLength {
case PacketNumberLen1:
epochDelta = PacketNumber(1) << 8
case PacketNumberLen2:
epochDelta = PacketNumber(1) << 16
case PacketNumberLen3:
epochDelta = PacketNumber(1) << 24
case PacketNumberLen4:
epochDelta = PacketNumber(1) << 32
}
epoch := lastPacketNumber & ^(epochDelta - 1)
var prevEpochBegin PacketNumber
if epoch > epochDelta {
prevEpochBegin = epoch - epochDelta
}
nextEpochBegin := epoch + epochDelta
return closestTo(
lastPacketNumber+1,
epoch+wirePacketNumber,
closestTo(lastPacketNumber+1, prevEpochBegin+wirePacketNumber, nextEpochBegin+wirePacketNumber),
)
}
func closestTo(target, a, b PacketNumber) PacketNumber {
if delta(target, a) < delta(target, b) {
return a
}
return b
}
func delta(a, b PacketNumber) PacketNumber {
if a < b {
return b - a
}
return a - b
}
// GetPacketNumberLengthForHeader gets the length of the packet number for the public header
// it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
diff := uint64(packetNumber - leastUnacked)
if diff < (1 << (16 - 1)) {
return PacketNumberLen2
}
if diff < (1 << (24 - 1)) {
return PacketNumberLen3
}
return PacketNumberLen4
}

View File

@@ -0,0 +1,195 @@
package protocol
import "time"
// DesiredReceiveBufferSize is the kernel UDP receive buffer size that we'd like to use.
const DesiredReceiveBufferSize = (1 << 20) * 2 // 2 MB
// InitialPacketSizeIPv4 is the maximum packet size that we use for sending IPv4 packets.
const InitialPacketSizeIPv4 = 1252
// InitialPacketSizeIPv6 is the maximum packet size that we use for sending IPv6 packets.
const InitialPacketSizeIPv6 = 1232
// MaxCongestionWindowPackets is the maximum congestion window in packet.
const MaxCongestionWindowPackets = 10000
// MaxUndecryptablePackets limits the number of undecryptable packets that are queued in the session.
const MaxUndecryptablePackets = 32
// ConnectionFlowControlMultiplier determines how much larger the connection flow control windows needs to be relative to any stream's flow control window
// This is the value that Chromium is using
const ConnectionFlowControlMultiplier = 1.5
// DefaultInitialMaxStreamData is the default initial stream-level flow control window for receiving data
const DefaultInitialMaxStreamData = (1 << 10) * 512 // 512 kb
// DefaultInitialMaxData is the connection-level flow control window for receiving data
const DefaultInitialMaxData = ConnectionFlowControlMultiplier * DefaultInitialMaxStreamData
// DefaultMaxReceiveStreamFlowControlWindow is the default maximum stream-level flow control window for receiving data
const DefaultMaxReceiveStreamFlowControlWindow = 6 * (1 << 20) // 6 MB
// DefaultMaxReceiveConnectionFlowControlWindow is the default connection-level flow control window for receiving data
const DefaultMaxReceiveConnectionFlowControlWindow = 15 * (1 << 20) // 15 MB
// WindowUpdateThreshold is the fraction of the receive window that has to be consumed before an higher offset is advertised to the client
const WindowUpdateThreshold = 0.25
// DefaultMaxIncomingStreams is the maximum number of streams that a peer may open
const DefaultMaxIncomingStreams = 100
// DefaultMaxIncomingUniStreams is the maximum number of unidirectional streams that a peer may open
const DefaultMaxIncomingUniStreams = 100
// MaxServerUnprocessedPackets is the max number of packets stored in the server that are not yet processed.
const MaxServerUnprocessedPackets = 1024
// MaxSessionUnprocessedPackets is the max number of packets stored in each session that are not yet processed.
const MaxSessionUnprocessedPackets = 256
// SkipPacketInitialPeriod is the initial period length used for packet number skipping to prevent an Optimistic ACK attack.
// Every time a packet number is skipped, the period is doubled, up to SkipPacketMaxPeriod.
const SkipPacketInitialPeriod PacketNumber = 256
// SkipPacketMaxPeriod is the maximum period length used for packet number skipping.
const SkipPacketMaxPeriod PacketNumber = 128 * 1024
// MaxAcceptQueueSize is the maximum number of sessions that the server queues for accepting.
// If the queue is full, new connection attempts will be rejected.
const MaxAcceptQueueSize = 32
// TokenValidity is the duration that a (non-retry) token is considered valid
const TokenValidity = 24 * time.Hour
// RetryTokenValidity is the duration that a retry token is considered valid
const RetryTokenValidity = 10 * time.Second
// MaxOutstandingSentPackets is maximum number of packets saved for retransmission.
// When reached, it imposes a soft limit on sending new packets:
// Sending ACKs and retransmission is still allowed, but now new regular packets can be sent.
const MaxOutstandingSentPackets = 2 * MaxCongestionWindowPackets
// MaxTrackedSentPackets is maximum number of sent packets saved for retransmission.
// When reached, no more packets will be sent.
// This value *must* be larger than MaxOutstandingSentPackets.
const MaxTrackedSentPackets = MaxOutstandingSentPackets * 5 / 4
// MaxNonAckElicitingAcks is the maximum number of packets containing an ACK,
// but no ack-eliciting frames, that we send in a row
const MaxNonAckElicitingAcks = 19
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// MinStreamFrameBufferSize is the minimum data length of a received STREAM frame
// that we use the buffer for. This protects against a DoS where an attacker would send us
// very small STREAM frames to consume a lot of memory.
const MinStreamFrameBufferSize = 128
// MinCoalescedPacketSize is the minimum size of a coalesced packet that we pack.
// If a packet has less than this number of bytes, we won't coalesce any more packets onto it.
const MinCoalescedPacketSize = 128
// MaxCryptoStreamOffset is the maximum offset allowed on any of the crypto streams.
// This limits the size of the ClientHello and Certificates that can be received.
const MaxCryptoStreamOffset = 16 * (1 << 10)
// MinRemoteIdleTimeout is the minimum value that we accept for the remote idle timeout
const MinRemoteIdleTimeout = 5 * time.Second
// DefaultIdleTimeout is the default idle timeout
const DefaultIdleTimeout = 30 * time.Second
// DefaultHandshakeIdleTimeout is the default idle timeout used before handshake completion.
const DefaultHandshakeIdleTimeout = 5 * time.Second
// DefaultHandshakeTimeout is the default timeout for a connection until the crypto handshake succeeds.
const DefaultHandshakeTimeout = 10 * time.Second
// MaxKeepAliveInterval is the maximum time until we send a packet to keep a connection alive.
// It should be shorter than the time that NATs clear their mapping.
const MaxKeepAliveInterval = 20 * time.Second
// RetiredConnectionIDDeleteTimeout is the time we keep closed sessions around in order to retransmit the CONNECTION_CLOSE.
// after this time all information about the old connection will be deleted
const RetiredConnectionIDDeleteTimeout = 5 * time.Second
// MinStreamFrameSize is the minimum size that has to be left in a packet, so that we add another STREAM frame.
// This avoids splitting up STREAM frames into small pieces, which has 2 advantages:
// 1. it reduces the framing overhead
// 2. it reduces the head-of-line blocking, when a packet is lost
const MinStreamFrameSize ByteCount = 128
// MaxPostHandshakeCryptoFrameSize is the maximum size of CRYPTO frames
// we send after the handshake completes.
const MaxPostHandshakeCryptoFrameSize = 1000
// MaxAckFrameSize is the maximum size for an ACK frame that we write
// Due to the varint encoding, ACK frames can grow (almost) indefinitely large.
// The MaxAckFrameSize should be large enough to encode many ACK range,
// but must ensure that a maximum size ACK frame fits into one packet.
const MaxAckFrameSize ByteCount = 1000
// MaxDatagramFrameSize is the maximum size of a DATAGRAM frame as defined in
// https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
// The size is chosen such that a DATAGRAM frame fits into a QUIC packet.
const MaxDatagramFrameSize ByteCount = 1220
// DatagramRcvQueueLen is the length of the receive queue for DATAGRAM frames.
// See https://datatracker.ietf.org/doc/draft-pauly-quic-datagram/.
const DatagramRcvQueueLen = 128
// MaxNumAckRanges is the maximum number of ACK ranges that we send in an ACK frame.
// It also serves as a limit for the packet history.
// If at any point we keep track of more ranges, old ranges are discarded.
const MaxNumAckRanges = 32
// MinPacingDelay is the minimum duration that is used for packet pacing
// If the packet packing frequency is higher, multiple packets might be sent at once.
// Example: For a packet pacing delay of 200μs, we would send 5 packets at once, wait for 1ms, and so forth.
const MinPacingDelay = time.Millisecond
// DefaultConnectionIDLength is the connection ID length that is used for multiplexed connections
// if no other value is configured.
const DefaultConnectionIDLength = 4
// MaxActiveConnectionIDs is the number of connection IDs that we're storing.
const MaxActiveConnectionIDs = 4
// MaxIssuedConnectionIDs is the maximum number of connection IDs that we're issuing at the same time.
const MaxIssuedConnectionIDs = 6
// PacketsPerConnectionID is the number of packets we send using one connection ID.
// If the peer provices us with enough new connection IDs, we switch to a new connection ID.
const PacketsPerConnectionID = 10000
// AckDelayExponent is the ack delay exponent used when sending ACKs.
const AckDelayExponent = 3
// Estimated timer granularity.
// The loss detection timer will not be set to a value smaller than granularity.
const TimerGranularity = time.Millisecond
// MaxAckDelay is the maximum time by which we delay sending ACKs.
const MaxAckDelay = 25 * time.Millisecond
// MaxAckDelayInclGranularity is the max_ack_delay including the timer granularity.
// This is the value that should be advertised to the peer.
const MaxAckDelayInclGranularity = MaxAckDelay + TimerGranularity
// KeyUpdateInterval is the maximum number of packets we send or receive before initiating a key update.
const KeyUpdateInterval = 100 * 1000
// Max0RTTQueueingDuration is the maximum time that we store 0-RTT packets in order to wait for the corresponding Initial to be received.
const Max0RTTQueueingDuration = 100 * time.Millisecond
// Max0RTTQueues is the maximum number of connections that we buffer 0-RTT packets for.
const Max0RTTQueues = 32
// Max0RTTQueueLen is the maximum number of 0-RTT packets that we buffer for each connection.
// When a new session is created, all buffered packets are passed to the session immediately.
// To avoid blocking, this value has to be smaller than MaxSessionUnprocessedPackets.
// To avoid packets being dropped as undecryptable by the session, this value has to be smaller than MaxUndecryptablePackets.
const Max0RTTQueueLen = 31

View File

@@ -0,0 +1,26 @@
package protocol
// Perspective determines if we're acting as a server or a client
type Perspective int
// the perspectives
const (
PerspectiveServer Perspective = 1
PerspectiveClient Perspective = 2
)
// Opposite returns the perspective of the peer
func (p Perspective) Opposite() Perspective {
return 3 - p
}
func (p Perspective) String() string {
switch p {
case PerspectiveServer:
return "Server"
case PerspectiveClient:
return "Client"
default:
return "invalid perspective"
}
}

View File

@@ -0,0 +1,97 @@
package protocol
import (
"fmt"
"time"
)
// The PacketType is the Long Header Type
type PacketType uint8
const (
// PacketTypeInitial is the packet type of an Initial packet
PacketTypeInitial PacketType = 1 + iota
// PacketTypeRetry is the packet type of a Retry packet
PacketTypeRetry
// PacketTypeHandshake is the packet type of a Handshake packet
PacketTypeHandshake
// PacketType0RTT is the packet type of a 0-RTT packet
PacketType0RTT
)
func (t PacketType) String() string {
switch t {
case PacketTypeInitial:
return "Initial"
case PacketTypeRetry:
return "Retry"
case PacketTypeHandshake:
return "Handshake"
case PacketType0RTT:
return "0-RTT Protected"
default:
return fmt.Sprintf("unknown packet type: %d", t)
}
}
type ECN uint8
const (
ECNNon ECN = iota // 00
ECT1 // 01
ECT0 // 10
ECNCE // 11
)
// A ByteCount in QUIC
type ByteCount int64
// MaxByteCount is the maximum value of a ByteCount
const MaxByteCount = ByteCount(1<<62 - 1)
// InvalidByteCount is an invalid byte count
const InvalidByteCount ByteCount = -1
// A StatelessResetToken is a stateless reset token.
type StatelessResetToken [16]byte
// MaxPacketBufferSize maximum packet size of any QUIC packet, based on
// ethernet's max size, minus the IP and UDP headers. IPv6 has a 40 byte header,
// UDP adds an additional 8 bytes. This is a total overhead of 48 bytes.
// Ethernet's max packet size is 1500 bytes, 1500 - 48 = 1452.
const MaxPacketBufferSize ByteCount = 1452
// MinInitialPacketSize is the minimum size an Initial packet is required to have.
const MinInitialPacketSize = 1200
// MinUnknownVersionPacketSize is the minimum size a packet with an unknown version
// needs to have in order to trigger a Version Negotiation packet.
const MinUnknownVersionPacketSize = MinInitialPacketSize
// MinStatelessResetSize is the minimum size of a stateless reset packet that we send
const MinStatelessResetSize = 1 /* first byte */ + 20 /* max. conn ID length */ + 4 /* max. packet number length */ + 1 /* min. payload length */ + 16 /* token */
// MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
const MinConnectionIDLenInitial = 8
// DefaultAckDelayExponent is the default ack delay exponent
const DefaultAckDelayExponent = 3
// MaxAckDelayExponent is the maximum ack delay exponent
const MaxAckDelayExponent = 20
// DefaultMaxAckDelay is the default max_ack_delay
const DefaultMaxAckDelay = 25 * time.Millisecond
// MaxMaxAckDelay is the maximum max_ack_delay
const MaxMaxAckDelay = (1<<14 - 1) * time.Millisecond
// MaxConnIDLen is the maximum length of the connection ID
const MaxConnIDLen = 20
// InvalidPacketLimitAES is the maximum number of packets that we can fail to decrypt when using
// AEAD_AES_128_GCM or AEAD_AES_265_GCM.
const InvalidPacketLimitAES = 1 << 52
// InvalidPacketLimitChaCha is the maximum number of packets that we can fail to decrypt when using AEAD_CHACHA20_POLY1305.
const InvalidPacketLimitChaCha = 1 << 36

View File

@@ -0,0 +1,76 @@
package protocol
// StreamType encodes if this is a unidirectional or bidirectional stream
type StreamType uint8
const (
// StreamTypeUni is a unidirectional stream
StreamTypeUni StreamType = iota
// StreamTypeBidi is a bidirectional stream
StreamTypeBidi
)
// InvalidPacketNumber is a stream ID that is invalid.
// The first valid stream ID in QUIC is 0.
const InvalidStreamID StreamID = -1
// StreamNum is the stream number
type StreamNum int64
const (
// InvalidStreamNum is an invalid stream number.
InvalidStreamNum = -1
// MaxStreamCount is the maximum stream count value that can be sent in MAX_STREAMS frames
// and as the stream count in the transport parameters
MaxStreamCount StreamNum = 1 << 60
)
// StreamID calculates the stream ID.
func (s StreamNum) StreamID(stype StreamType, pers Perspective) StreamID {
if s == 0 {
return InvalidStreamID
}
var first StreamID
switch stype {
case StreamTypeBidi:
switch pers {
case PerspectiveClient:
first = 0
case PerspectiveServer:
first = 1
}
case StreamTypeUni:
switch pers {
case PerspectiveClient:
first = 2
case PerspectiveServer:
first = 3
}
}
return first + 4*StreamID(s-1)
}
// A StreamID in QUIC
type StreamID int64
// InitiatedBy says if the stream was initiated by the client or by the server
func (s StreamID) InitiatedBy() Perspective {
if s%2 == 0 {
return PerspectiveClient
}
return PerspectiveServer
}
// Type says if this is a unidirectional or bidirectional stream
func (s StreamID) Type() StreamType {
if s%4 >= 2 {
return StreamTypeUni
}
return StreamTypeBidi
}
// StreamNum returns how many streams in total are below this
// Example: for stream 9 it returns 3 (i.e. streams 1, 5 and 9)
func (s StreamID) StreamNum() StreamNum {
return StreamNum(s/4) + 1
}

View File

@@ -0,0 +1,134 @@
package protocol
import (
"crypto/rand"
"encoding/binary"
"fmt"
"math"
)
// VersionNumber is a version number as int
type VersionNumber uint32
// gQUIC version range as defined in the wiki: https://github.com/quicwg/base-drafts/wiki/QUIC-Versions
const (
gquicVersion0 = 0x51303030
maxGquicVersion = 0x51303439
)
// The version numbers, making grepping easier
const (
VersionTLS VersionNumber = 0x1
VersionWhatever VersionNumber = math.MaxUint32 - 1 // for when the version doesn't matter
VersionUnknown VersionNumber = math.MaxUint32
VersionDraft29 VersionNumber = 0xff00001d
VersionDraft32 VersionNumber = 0xff000020
VersionDraft34 VersionNumber = 0xff000022
Version1 VersionNumber = 0x1
)
// SupportedVersions lists the versions that the server supports
// must be in sorted descending order
var SupportedVersions = []VersionNumber{Version1, VersionDraft34, VersionDraft32, VersionDraft29}
// IsValidVersion says if the version is known to quic-go
func IsValidVersion(v VersionNumber) bool {
return v == VersionTLS || IsSupportedVersion(SupportedVersions, v)
}
func (vn VersionNumber) String() string {
// For releases, VersionTLS will be set to a draft version.
// A switch statement can't contain duplicate cases.
if vn == VersionTLS && VersionTLS != VersionDraft29 && VersionTLS != VersionDraft32 && VersionTLS != Version1 {
return "TLS dev version (WIP)"
}
//nolint:exhaustive
switch vn {
case VersionWhatever:
return "whatever"
case VersionUnknown:
return "unknown"
case VersionDraft29:
return "draft-29"
case VersionDraft32:
return "draft-32"
case VersionDraft34:
return "draft-34"
case Version1:
return "v1"
default:
if vn.isGQUIC() {
return fmt.Sprintf("gQUIC %d", vn.toGQUICVersion())
}
return fmt.Sprintf("%#x", uint32(vn))
}
}
func (vn VersionNumber) isGQUIC() bool {
return vn > gquicVersion0 && vn <= maxGquicVersion
}
func (vn VersionNumber) toGQUICVersion() int {
return int(10*(vn-gquicVersion0)/0x100) + int(vn%0x10)
}
// UseRetireBugBackwardsCompatibilityMode says if it is necessary to use the backwards compatilibity mode.
// This is only the case if it 1. is enabled and 2. draft-29 is used.
func UseRetireBugBackwardsCompatibilityMode(enabled bool, v VersionNumber) bool {
return enabled && v == VersionDraft29
}
// IsSupportedVersion returns true if the server supports this version
func IsSupportedVersion(supported []VersionNumber, v VersionNumber) bool {
for _, t := range supported {
if t == v {
return true
}
}
return false
}
// ChooseSupportedVersion finds the best version in the overlap of ours and theirs
// ours is a slice of versions that we support, sorted by our preference (descending)
// theirs is a slice of versions offered by the peer. The order does not matter.
// The bool returned indicates if a matching version was found.
func ChooseSupportedVersion(ours, theirs []VersionNumber) (VersionNumber, bool) {
for _, ourVer := range ours {
for _, theirVer := range theirs {
if ourVer == theirVer {
return ourVer, true
}
}
}
return 0, false
}
// generateReservedVersion generates a reserved version number (v & 0x0f0f0f0f == 0x0a0a0a0a)
func generateReservedVersion() VersionNumber {
b := make([]byte, 4)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
return VersionNumber((binary.BigEndian.Uint32(b) | 0x0a0a0a0a) & 0xfafafafa)
}
// GetGreasedVersions adds one reserved version number to a slice of version numbers, at a random position
func GetGreasedVersions(supported []VersionNumber) []VersionNumber {
b := make([]byte, 1)
_, _ = rand.Read(b) // ignore the error here. Failure to read random data doesn't break anything
randPos := int(b[0]) % (len(supported) + 1)
greased := make([]VersionNumber, len(supported)+1)
copy(greased, supported[:randPos])
greased[randPos] = generateReservedVersion()
copy(greased[randPos+1:], supported[randPos:])
return greased
}
// StripGreasedVersions strips all greased versions from a slice of versions
func StripGreasedVersions(versions []VersionNumber) []VersionNumber {
realVersions := make([]VersionNumber, 0, len(versions))
for _, v := range versions {
if v&0x0f0f0f0f != 0x0a0a0a0a {
realVersions = append(realVersions, v)
}
}
return realVersions
}

View File

@@ -0,0 +1,88 @@
package qerr
import (
"fmt"
"github.com/lucas-clemente/quic-go/internal/qtls"
)
// TransportErrorCode is a QUIC transport error.
type TransportErrorCode uint64
// The error codes defined by QUIC
const (
NoError TransportErrorCode = 0x0
InternalError TransportErrorCode = 0x1
ConnectionRefused TransportErrorCode = 0x2
FlowControlError TransportErrorCode = 0x3
StreamLimitError TransportErrorCode = 0x4
StreamStateError TransportErrorCode = 0x5
FinalSizeError TransportErrorCode = 0x6
FrameEncodingError TransportErrorCode = 0x7
TransportParameterError TransportErrorCode = 0x8
ConnectionIDLimitError TransportErrorCode = 0x9
ProtocolViolation TransportErrorCode = 0xa
InvalidToken TransportErrorCode = 0xb
ApplicationErrorErrorCode TransportErrorCode = 0xc
CryptoBufferExceeded TransportErrorCode = 0xd
KeyUpdateError TransportErrorCode = 0xe
AEADLimitReached TransportErrorCode = 0xf
NoViablePathError TransportErrorCode = 0x10
)
func (e TransportErrorCode) IsCryptoError() bool {
return e >= 0x100 && e < 0x200
}
// Message is a description of the error.
// It only returns a non-empty string for crypto errors.
func (e TransportErrorCode) Message() string {
if !e.IsCryptoError() {
return ""
}
return qtls.Alert(e - 0x100).Error()
}
func (e TransportErrorCode) String() string {
switch e {
case NoError:
return "NO_ERROR"
case InternalError:
return "INTERNAL_ERROR"
case ConnectionRefused:
return "CONNECTION_REFUSED"
case FlowControlError:
return "FLOW_CONTROL_ERROR"
case StreamLimitError:
return "STREAM_LIMIT_ERROR"
case StreamStateError:
return "STREAM_STATE_ERROR"
case FinalSizeError:
return "FINAL_SIZE_ERROR"
case FrameEncodingError:
return "FRAME_ENCODING_ERROR"
case TransportParameterError:
return "TRANSPORT_PARAMETER_ERROR"
case ConnectionIDLimitError:
return "CONNECTION_ID_LIMIT_ERROR"
case ProtocolViolation:
return "PROTOCOL_VIOLATION"
case InvalidToken:
return "INVALID_TOKEN"
case ApplicationErrorErrorCode:
return "APPLICATION_ERROR"
case CryptoBufferExceeded:
return "CRYPTO_BUFFER_EXCEEDED"
case KeyUpdateError:
return "KEY_UPDATE_ERROR"
case AEADLimitReached:
return "AEAD_LIMIT_REACHED"
case NoViablePathError:
return "NO_VIABLE_PATH"
default:
if e.IsCryptoError() {
return fmt.Sprintf("CRYPTO_ERROR (%#x)", uint16(e))
}
return fmt.Sprintf("unknown error code: %#x", uint16(e))
}
}

View File

@@ -0,0 +1,106 @@
package qerr
import (
"fmt"
"net"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var (
ErrHandshakeTimeout = &HandshakeTimeoutError{}
ErrIdleTimeout = &IdleTimeoutError{}
)
type TransportError struct {
Remote bool
FrameType uint64
ErrorCode TransportErrorCode
ErrorMessage string
}
var _ error = &TransportError{}
// NewCryptoError create a new TransportError instance for a crypto error
func NewCryptoError(tlsAlert uint8, errorMessage string) *TransportError {
return &TransportError{
ErrorCode: 0x100 + TransportErrorCode(tlsAlert),
ErrorMessage: errorMessage,
}
}
func (e *TransportError) Error() string {
str := e.ErrorCode.String()
if e.FrameType != 0 {
str += fmt.Sprintf(" (frame type: %#x)", e.FrameType)
}
msg := e.ErrorMessage
if len(msg) == 0 {
msg = e.ErrorCode.Message()
}
if len(msg) == 0 {
return str
}
return str + ": " + msg
}
// An ApplicationErrorCode is an application-defined error code.
type ApplicationErrorCode uint64
// A StreamErrorCode is an error code used to cancel streams.
type StreamErrorCode uint64
type ApplicationError struct {
Remote bool
ErrorCode ApplicationErrorCode
ErrorMessage string
}
var _ error = &ApplicationError{}
func (e *ApplicationError) Error() string {
if len(e.ErrorMessage) == 0 {
return fmt.Sprintf("Application error %#x", e.ErrorCode)
}
return fmt.Sprintf("Application error %#x: %s", e.ErrorCode, e.ErrorMessage)
}
type IdleTimeoutError struct{}
var _ error = &IdleTimeoutError{}
func (e *IdleTimeoutError) Timeout() bool { return true }
func (e *IdleTimeoutError) Temporary() bool { return false }
func (e *IdleTimeoutError) Error() string { return "timeout: no recent network activity" }
type HandshakeTimeoutError struct{}
var _ error = &HandshakeTimeoutError{}
func (e *HandshakeTimeoutError) Timeout() bool { return true }
func (e *HandshakeTimeoutError) Temporary() bool { return false }
func (e *HandshakeTimeoutError) Error() string { return "timeout: handshake did not complete in time" }
// A VersionNegotiationError occurs when the client and the server can't agree on a QUIC version.
type VersionNegotiationError struct {
Ours []protocol.VersionNumber
Theirs []protocol.VersionNumber
}
func (e *VersionNegotiationError) Error() string {
return fmt.Sprintf("no compatible QUIC version found (we support %s, server offered %s)", e.Ours, e.Theirs)
}
// A StatelessResetError occurs when we receive a stateless reset.
type StatelessResetError struct {
Token protocol.StatelessResetToken
}
var _ net.Error = &StatelessResetError{}
func (e *StatelessResetError) Error() string {
return fmt.Sprintf("received a stateless reset with token %x", e.Token)
}
func (e *StatelessResetError) Timeout() bool { return false }
func (e *StatelessResetError) Temporary() bool { return true }

View File

@@ -0,0 +1,55 @@
// +build go1.16
package qerr
import (
"net"
)
func (e *TransportError) Is(target error) bool {
_, ok := target.(*TransportError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *ApplicationError) Is(target error) bool {
_, ok := target.(*ApplicationError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *IdleTimeoutError) Is(target error) bool {
_, ok := target.(*IdleTimeoutError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *HandshakeTimeoutError) Is(target error) bool {
_, ok := target.(*HandshakeTimeoutError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *VersionNegotiationError) Is(target error) bool {
_, ok := target.(*VersionNegotiationError)
if ok {
return true
}
return target == net.ErrClosed
}
func (e *StatelessResetError) Is(target error) bool {
_, ok := target.(*StatelessResetError)
if ok {
return true
}
return target == net.ErrClosed
}

View File

@@ -0,0 +1,33 @@
// +build !go1.16
package qerr
func (e *TransportError) Is(target error) bool {
_, ok := target.(*TransportError)
return ok
}
func (e *ApplicationError) Is(target error) bool {
_, ok := target.(*ApplicationError)
return ok
}
func (e *IdleTimeoutError) Is(target error) bool {
_, ok := target.(*IdleTimeoutError)
return ok
}
func (e *HandshakeTimeoutError) Is(target error) bool {
_, ok := target.(*HandshakeTimeoutError)
return ok
}
func (e *VersionNegotiationError) Is(target error) bool {
_, ok := target.(*VersionNegotiationError)
return ok
}
func (e *StatelessResetError) Is(target error) bool {
_, ok := target.(*StatelessResetError)
return ok
}

View File

@@ -0,0 +1,100 @@
// +build go1.15
// +build !go1.16
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"net"
"unsafe"
"github.com/marten-seemann/qtls-go1-15"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains inforamtion about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-15.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}

View File

@@ -0,0 +1,100 @@
// +build go1.16
// +build !go1.17
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"net"
"unsafe"
"github.com/marten-seemann/qtls-go1-16"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains inforamtion about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-16.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}

View File

@@ -0,0 +1,99 @@
// +build go1.17
package qtls
import (
"crypto"
"crypto/cipher"
"crypto/tls"
"net"
"unsafe"
"github.com/marten-seemann/qtls-go1-17"
)
type (
// Alert is a TLS alert
Alert = qtls.Alert
// A Certificate is qtls.Certificate.
Certificate = qtls.Certificate
// CertificateRequestInfo contains inforamtion about a certificate request.
CertificateRequestInfo = qtls.CertificateRequestInfo
// A CipherSuiteTLS13 is a cipher suite for TLS 1.3
CipherSuiteTLS13 = qtls.CipherSuiteTLS13
// ClientHelloInfo contains information about a ClientHello.
ClientHelloInfo = qtls.ClientHelloInfo
// ClientSessionCache is a cache used for session resumption.
ClientSessionCache = qtls.ClientSessionCache
// ClientSessionState is a state needed for session resumption.
ClientSessionState = qtls.ClientSessionState
// A Config is a qtls.Config.
Config = qtls.Config
// A Conn is a qtls.Conn.
Conn = qtls.Conn
// ConnectionState contains information about the state of the connection.
ConnectionState = qtls.ConnectionStateWith0RTT
// EncryptionLevel is the encryption level of a message.
EncryptionLevel = qtls.EncryptionLevel
// Extension is a TLS extension
Extension = qtls.Extension
// ExtraConfig is the qtls.ExtraConfig
ExtraConfig = qtls.ExtraConfig
// RecordLayer is a qtls RecordLayer.
RecordLayer = qtls.RecordLayer
)
const (
// EncryptionHandshake is the Handshake encryption level
EncryptionHandshake = qtls.EncryptionHandshake
// Encryption0RTT is the 0-RTT encryption level
Encryption0RTT = qtls.Encryption0RTT
// EncryptionApplication is the application data encryption level
EncryptionApplication = qtls.EncryptionApplication
)
// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3
func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD {
return qtls.AEADAESGCMTLS13(key, fixedNonce)
}
// Client returns a new TLS client side connection.
func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Client(conn, config, extraConfig)
}
// Server returns a new TLS server side connection.
func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn {
return qtls.Server(conn, config, extraConfig)
}
func GetConnectionState(conn *Conn) ConnectionState {
return conn.ConnectionStateWith0RTT()
}
// ToTLSConnectionState extracts the tls.ConnectionState
func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState {
return cs.ConnectionState
}
type cipherSuiteTLS13 struct {
ID uint16
KeyLen int
AEAD func(key, fixedNonce []byte) cipher.AEAD
Hash crypto.Hash
}
//go:linkname cipherSuiteTLS13ByID github.com/marten-seemann/qtls-go1-17.cipherSuiteTLS13ByID
func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13
// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite.
func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 {
val := cipherSuiteTLS13ByID(id)
cs := (*cipherSuiteTLS13)(unsafe.Pointer(val))
return &qtls.CipherSuiteTLS13{
ID: cs.ID,
KeyLen: cs.KeyLen,
AEAD: cs.AEAD,
Hash: cs.Hash,
}
}

View File

@@ -0,0 +1,5 @@
// +build go1.18
package qtls
var _ int = "quic-go doesn't build on Go 1.18 yet."

View File

@@ -0,0 +1,22 @@
package utils
import "sync/atomic"
// An AtomicBool is an atomic bool
type AtomicBool struct {
v int32
}
// Set sets the value
func (a *AtomicBool) Set(value bool) {
var n int32
if value {
n = 1
}
atomic.StoreInt32(&a.v, n)
}
// Get gets the value
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.v) != 0
}

View File

@@ -0,0 +1,26 @@
package utils
import (
"bufio"
"io"
)
type bufferedWriteCloser struct {
*bufio.Writer
io.Closer
}
// NewBufferedWriteCloser creates an io.WriteCloser from a bufio.Writer and an io.Closer
func NewBufferedWriteCloser(writer *bufio.Writer, closer io.Closer) io.WriteCloser {
return &bufferedWriteCloser{
Writer: writer,
Closer: closer,
}
}
func (h bufferedWriteCloser) Close() error {
if err := h.Writer.Flush(); err != nil {
return err
}
return h.Closer.Close()
}

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package utils
// Linked list implementation from the Go standard library.
// ByteIntervalElement is an element of a linked list.
type ByteIntervalElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *ByteIntervalElement
// The list to which this element belongs.
list *ByteIntervalList
// The value stored with this element.
Value ByteInterval
}
// Next returns the next list element or nil.
func (e *ByteIntervalElement) Next() *ByteIntervalElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *ByteIntervalElement) Prev() *ByteIntervalElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// ByteIntervalList is a linked list of ByteIntervals.
type ByteIntervalList struct {
root ByteIntervalElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *ByteIntervalList) Init() *ByteIntervalList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewByteIntervalList returns an initialized list.
func NewByteIntervalList() *ByteIntervalList { return new(ByteIntervalList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *ByteIntervalList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *ByteIntervalList) Front() *ByteIntervalElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *ByteIntervalList) Back() *ByteIntervalElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *ByteIntervalList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *ByteIntervalList) insert(e, at *ByteIntervalElement) *ByteIntervalElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *ByteIntervalList) insertValue(v ByteInterval, at *ByteIntervalElement) *ByteIntervalElement {
return l.insert(&ByteIntervalElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *ByteIntervalList) remove(e *ByteIntervalElement) *ByteIntervalElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *ByteIntervalList) Remove(e *ByteIntervalElement) ByteInterval {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *ByteIntervalList) PushFront(v ByteInterval) *ByteIntervalElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *ByteIntervalList) PushBack(v ByteInterval) *ByteIntervalElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertBefore(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *ByteIntervalList) InsertAfter(v ByteInterval, mark *ByteIntervalElement) *ByteIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToFront(e *ByteIntervalElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *ByteIntervalList) MoveToBack(e *ByteIntervalElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveBefore(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *ByteIntervalList) MoveAfter(e, mark *ByteIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushBackList(other *ByteIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *ByteIntervalList) PushFrontList(other *ByteIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,17 @@
package utils
import (
"bytes"
"io"
)
// A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
type ByteOrder interface {
ReadUint32(io.ByteReader) (uint32, error)
ReadUint24(io.ByteReader) (uint32, error)
ReadUint16(io.ByteReader) (uint16, error)
WriteUint32(*bytes.Buffer, uint32)
WriteUint24(*bytes.Buffer, uint32)
WriteUint16(*bytes.Buffer, uint16)
}

View File

@@ -0,0 +1,89 @@
package utils
import (
"bytes"
"io"
)
// BigEndian is the big-endian implementation of ByteOrder.
var BigEndian ByteOrder = bigEndian{}
type bigEndian struct{}
var _ ByteOrder = &bigEndian{}
// ReadUintN reads N bytes
func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
var res uint64
for i := uint8(0); i < length; i++ {
bt, err := b.ReadByte()
if err != nil {
return 0, err
}
res ^= uint64(bt) << ((length - 1 - i) * 8)
}
return res, nil
}
// ReadUint32 reads a uint32
func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
var b1, b2, b3, b4 uint8
var err error
if b4, err = b.ReadByte(); err != nil {
return 0, err
}
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16 + uint32(b4)<<24, nil
}
// ReadUint24 reads a uint24
func (bigEndian) ReadUint24(b io.ByteReader) (uint32, error) {
var b1, b2, b3 uint8
var err error
if b3, err = b.ReadByte(); err != nil {
return 0, err
}
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint32(b1) + uint32(b2)<<8 + uint32(b3)<<16, nil
}
// ReadUint16 reads a uint16
func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
var b1, b2 uint8
var err error
if b2, err = b.ReadByte(); err != nil {
return 0, err
}
if b1, err = b.ReadByte(); err != nil {
return 0, err
}
return uint16(b1) + uint16(b2)<<8, nil
}
// WriteUint32 writes a uint32
func (bigEndian) WriteUint32(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint24 writes a uint24
func (bigEndian) WriteUint24(b *bytes.Buffer, i uint32) {
b.Write([]byte{uint8(i >> 16), uint8(i >> 8), uint8(i)})
}
// WriteUint16 writes a uint16
func (bigEndian) WriteUint16(b *bytes.Buffer, i uint16) {
b.Write([]byte{uint8(i >> 8), uint8(i)})
}

View File

@@ -0,0 +1,5 @@
package utils
//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out byteinterval_linkedlist.go gen Item=ByteInterval
//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out packetinterval_linkedlist.go gen Item=PacketInterval
//go:generate genny -pkg utils -in linkedlist/linkedlist.go -out newconnectionid_linkedlist.go gen Item=NewConnectionID

View File

@@ -0,0 +1,10 @@
package utils
import "net"
func IsIPv4(ip net.IP) bool {
// If ip is not an IPv4 address, To4 returns nil.
// Note that there might be some corner cases, where this is not correct.
// See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
return ip.To4() != nil
}

View File

@@ -0,0 +1,131 @@
package utils
import (
"fmt"
"log"
"os"
"strings"
"time"
)
// LogLevel of quic-go
type LogLevel uint8
const (
// LogLevelNothing disables
LogLevelNothing LogLevel = iota
// LogLevelError enables err logs
LogLevelError
// LogLevelInfo enables info logs (e.g. packets)
LogLevelInfo
// LogLevelDebug enables debug logs (e.g. packet contents)
LogLevelDebug
)
const logEnv = "QUIC_GO_LOG_LEVEL"
// A Logger logs.
type Logger interface {
SetLogLevel(LogLevel)
SetLogTimeFormat(format string)
WithPrefix(prefix string) Logger
Debug() bool
Errorf(format string, args ...interface{})
Infof(format string, args ...interface{})
Debugf(format string, args ...interface{})
}
// DefaultLogger is used by quic-go for logging.
var DefaultLogger Logger
type defaultLogger struct {
prefix string
logLevel LogLevel
timeFormat string
}
var _ Logger = &defaultLogger{}
// SetLogLevel sets the log level
func (l *defaultLogger) SetLogLevel(level LogLevel) {
l.logLevel = level
}
// SetLogTimeFormat sets the format of the timestamp
// an empty string disables the logging of timestamps
func (l *defaultLogger) SetLogTimeFormat(format string) {
log.SetFlags(0) // disable timestamp logging done by the log package
l.timeFormat = format
}
// Debugf logs something
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
if l.logLevel == LogLevelDebug {
l.logMessage(format, args...)
}
}
// Infof logs something
func (l *defaultLogger) Infof(format string, args ...interface{}) {
if l.logLevel >= LogLevelInfo {
l.logMessage(format, args...)
}
}
// Errorf logs something
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
if l.logLevel >= LogLevelError {
l.logMessage(format, args...)
}
}
func (l *defaultLogger) logMessage(format string, args ...interface{}) {
var pre string
if len(l.timeFormat) > 0 {
pre = time.Now().Format(l.timeFormat) + " "
}
if len(l.prefix) > 0 {
pre += l.prefix + " "
}
log.Printf(pre+format, args...)
}
func (l *defaultLogger) WithPrefix(prefix string) Logger {
if len(l.prefix) > 0 {
prefix = l.prefix + " " + prefix
}
return &defaultLogger{
logLevel: l.logLevel,
timeFormat: l.timeFormat,
prefix: prefix,
}
}
// Debug returns true if the log level is LogLevelDebug
func (l *defaultLogger) Debug() bool {
return l.logLevel == LogLevelDebug
}
func init() {
DefaultLogger = &defaultLogger{}
DefaultLogger.SetLogLevel(readLoggingEnv())
}
func readLoggingEnv() LogLevel {
switch strings.ToLower(os.Getenv(logEnv)) {
case "":
return LogLevelNothing
case "debug":
return LogLevelDebug
case "info":
return LogLevelInfo
case "error":
return LogLevelError
default:
fmt.Fprintln(os.Stderr, "invalid quic-go log level, see https://github.com/lucas-clemente/quic-go/wiki/Logging")
return LogLevelNothing
}
}

View File

@@ -0,0 +1,170 @@
package utils
import (
"math"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// InfDuration is a duration of infinite length
const InfDuration = time.Duration(math.MaxInt64)
// Max returns the maximum of two Ints
func Max(a, b int) int {
if a < b {
return b
}
return a
}
// MaxUint32 returns the maximum of two uint32
func MaxUint32(a, b uint32) uint32 {
if a < b {
return b
}
return a
}
// MaxUint64 returns the maximum of two uint64
func MaxUint64(a, b uint64) uint64 {
if a < b {
return b
}
return a
}
// MinUint64 returns the maximum of two uint64
func MinUint64(a, b uint64) uint64 {
if a < b {
return a
}
return b
}
// Min returns the minimum of two Ints
func Min(a, b int) int {
if a < b {
return a
}
return b
}
// MinUint32 returns the maximum of two uint32
func MinUint32(a, b uint32) uint32 {
if a < b {
return a
}
return b
}
// MinInt64 returns the minimum of two int64
func MinInt64(a, b int64) int64 {
if a < b {
return a
}
return b
}
// MaxInt64 returns the minimum of two int64
func MaxInt64(a, b int64) int64 {
if a > b {
return a
}
return b
}
// MinByteCount returns the minimum of two ByteCounts
func MinByteCount(a, b protocol.ByteCount) protocol.ByteCount {
if a < b {
return a
}
return b
}
// MaxByteCount returns the maximum of two ByteCounts
func MaxByteCount(a, b protocol.ByteCount) protocol.ByteCount {
if a < b {
return b
}
return a
}
// MaxDuration returns the max duration
func MaxDuration(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}
// MinDuration returns the minimum duration
func MinDuration(a, b time.Duration) time.Duration {
if a > b {
return b
}
return a
}
// MinNonZeroDuration return the minimum duration that's not zero.
func MinNonZeroDuration(a, b time.Duration) time.Duration {
if a == 0 {
return b
}
if b == 0 {
return a
}
return MinDuration(a, b)
}
// AbsDuration returns the absolute value of a time duration
func AbsDuration(d time.Duration) time.Duration {
if d >= 0 {
return d
}
return -d
}
// MinTime returns the earlier time
func MinTime(a, b time.Time) time.Time {
if a.After(b) {
return b
}
return a
}
// MinNonZeroTime returns the earlist time that is not time.Time{}
// If both a and b are time.Time{}, it returns time.Time{}
func MinNonZeroTime(a, b time.Time) time.Time {
if a.IsZero() {
return b
}
if b.IsZero() {
return a
}
return MinTime(a, b)
}
// MaxTime returns the later time
func MaxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
return b
}
// MaxPacketNumber returns the max packet number
func MaxPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber {
if a > b {
return a
}
return b
}
// MinPacketNumber returns the min packet number
func MinPacketNumber(a, b protocol.PacketNumber) protocol.PacketNumber {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,12 @@
package utils
import (
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// NewConnectionID is a new connection ID
type NewConnectionID struct {
SequenceNumber uint64
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package utils
// Linked list implementation from the Go standard library.
// NewConnectionIDElement is an element of a linked list.
type NewConnectionIDElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *NewConnectionIDElement
// The list to which this element belongs.
list *NewConnectionIDList
// The value stored with this element.
Value NewConnectionID
}
// Next returns the next list element or nil.
func (e *NewConnectionIDElement) Next() *NewConnectionIDElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *NewConnectionIDElement) Prev() *NewConnectionIDElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// NewConnectionIDList is a linked list of NewConnectionIDs.
type NewConnectionIDList struct {
root NewConnectionIDElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *NewConnectionIDList) Init() *NewConnectionIDList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewNewConnectionIDList returns an initialized list.
func NewNewConnectionIDList() *NewConnectionIDList { return new(NewConnectionIDList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *NewConnectionIDList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *NewConnectionIDList) Front() *NewConnectionIDElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *NewConnectionIDList) Back() *NewConnectionIDElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *NewConnectionIDList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *NewConnectionIDList) insert(e, at *NewConnectionIDElement) *NewConnectionIDElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *NewConnectionIDList) insertValue(v NewConnectionID, at *NewConnectionIDElement) *NewConnectionIDElement {
return l.insert(&NewConnectionIDElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *NewConnectionIDList) remove(e *NewConnectionIDElement) *NewConnectionIDElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *NewConnectionIDList) Remove(e *NewConnectionIDElement) NewConnectionID {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *NewConnectionIDList) PushFront(v NewConnectionID) *NewConnectionIDElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *NewConnectionIDList) PushBack(v NewConnectionID) *NewConnectionIDElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *NewConnectionIDList) InsertBefore(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *NewConnectionIDList) InsertAfter(v NewConnectionID, mark *NewConnectionIDElement) *NewConnectionIDElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *NewConnectionIDList) MoveToFront(e *NewConnectionIDElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *NewConnectionIDList) MoveToBack(e *NewConnectionIDElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *NewConnectionIDList) MoveBefore(e, mark *NewConnectionIDElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *NewConnectionIDList) MoveAfter(e, mark *NewConnectionIDElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *NewConnectionIDList) PushBackList(other *NewConnectionIDList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *NewConnectionIDList) PushFrontList(other *NewConnectionIDList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,9 @@
package utils
import "github.com/lucas-clemente/quic-go/internal/protocol"
// PacketInterval is an interval from one PacketNumber to the other
type PacketInterval struct {
Start protocol.PacketNumber
End protocol.PacketNumber
}

View File

@@ -0,0 +1,217 @@
// This file was automatically generated by genny.
// Any changes will be lost if this file is regenerated.
// see https://github.com/cheekybits/genny
package utils
// Linked list implementation from the Go standard library.
// PacketIntervalElement is an element of a linked list.
type PacketIntervalElement struct {
// Next and previous pointers in the doubly-linked list of elements.
// To simplify the implementation, internally a list l is implemented
// as a ring, such that &l.root is both the next element of the last
// list element (l.Back()) and the previous element of the first list
// element (l.Front()).
next, prev *PacketIntervalElement
// The list to which this element belongs.
list *PacketIntervalList
// The value stored with this element.
Value PacketInterval
}
// Next returns the next list element or nil.
func (e *PacketIntervalElement) Next() *PacketIntervalElement {
if p := e.next; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// Prev returns the previous list element or nil.
func (e *PacketIntervalElement) Prev() *PacketIntervalElement {
if p := e.prev; e.list != nil && p != &e.list.root {
return p
}
return nil
}
// PacketIntervalList is a linked list of PacketIntervals.
type PacketIntervalList struct {
root PacketIntervalElement // sentinel list element, only &root, root.prev, and root.next are used
len int // current list length excluding (this) sentinel element
}
// Init initializes or clears list l.
func (l *PacketIntervalList) Init() *PacketIntervalList {
l.root.next = &l.root
l.root.prev = &l.root
l.len = 0
return l
}
// NewPacketIntervalList returns an initialized list.
func NewPacketIntervalList() *PacketIntervalList { return new(PacketIntervalList).Init() }
// Len returns the number of elements of list l.
// The complexity is O(1).
func (l *PacketIntervalList) Len() int { return l.len }
// Front returns the first element of list l or nil if the list is empty.
func (l *PacketIntervalList) Front() *PacketIntervalElement {
if l.len == 0 {
return nil
}
return l.root.next
}
// Back returns the last element of list l or nil if the list is empty.
func (l *PacketIntervalList) Back() *PacketIntervalElement {
if l.len == 0 {
return nil
}
return l.root.prev
}
// lazyInit lazily initializes a zero List value.
func (l *PacketIntervalList) lazyInit() {
if l.root.next == nil {
l.Init()
}
}
// insert inserts e after at, increments l.len, and returns e.
func (l *PacketIntervalList) insert(e, at *PacketIntervalElement) *PacketIntervalElement {
n := at.next
at.next = e
e.prev = at
e.next = n
n.prev = e
e.list = l
l.len++
return e
}
// insertValue is a convenience wrapper for insert(&Element{Value: v}, at).
func (l *PacketIntervalList) insertValue(v PacketInterval, at *PacketIntervalElement) *PacketIntervalElement {
return l.insert(&PacketIntervalElement{Value: v}, at)
}
// remove removes e from its list, decrements l.len, and returns e.
func (l *PacketIntervalList) remove(e *PacketIntervalElement) *PacketIntervalElement {
e.prev.next = e.next
e.next.prev = e.prev
e.next = nil // avoid memory leaks
e.prev = nil // avoid memory leaks
e.list = nil
l.len--
return e
}
// Remove removes e from l if e is an element of list l.
// It returns the element value e.Value.
// The element must not be nil.
func (l *PacketIntervalList) Remove(e *PacketIntervalElement) PacketInterval {
if e.list == l {
// if e.list == l, l must have been initialized when e was inserted
// in l or l == nil (e is a zero Element) and l.remove will crash
l.remove(e)
}
return e.Value
}
// PushFront inserts a new element e with value v at the front of list l and returns e.
func (l *PacketIntervalList) PushFront(v PacketInterval) *PacketIntervalElement {
l.lazyInit()
return l.insertValue(v, &l.root)
}
// PushBack inserts a new element e with value v at the back of list l and returns e.
func (l *PacketIntervalList) PushBack(v PacketInterval) *PacketIntervalElement {
l.lazyInit()
return l.insertValue(v, l.root.prev)
}
// InsertBefore inserts a new element e with value v immediately before mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketIntervalList) InsertBefore(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark.prev)
}
// InsertAfter inserts a new element e with value v immediately after mark and returns e.
// If mark is not an element of l, the list is not modified.
// The mark must not be nil.
func (l *PacketIntervalList) InsertAfter(v PacketInterval, mark *PacketIntervalElement) *PacketIntervalElement {
if mark.list != l {
return nil
}
// see comment in List.Remove about initialization of l
return l.insertValue(v, mark)
}
// MoveToFront moves element e to the front of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketIntervalList) MoveToFront(e *PacketIntervalElement) {
if e.list != l || l.root.next == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), &l.root)
}
// MoveToBack moves element e to the back of list l.
// If e is not an element of l, the list is not modified.
// The element must not be nil.
func (l *PacketIntervalList) MoveToBack(e *PacketIntervalElement) {
if e.list != l || l.root.prev == e {
return
}
// see comment in List.Remove about initialization of l
l.insert(l.remove(e), l.root.prev)
}
// MoveBefore moves element e to its new position before mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketIntervalList) MoveBefore(e, mark *PacketIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark.prev)
}
// MoveAfter moves element e to its new position after mark.
// If e or mark is not an element of l, or e == mark, the list is not modified.
// The element and mark must not be nil.
func (l *PacketIntervalList) MoveAfter(e, mark *PacketIntervalElement) {
if e.list != l || e == mark || mark.list != l {
return
}
l.insert(l.remove(e), mark)
}
// PushBackList inserts a copy of an other list at the back of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketIntervalList) PushBackList(other *PacketIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() {
l.insertValue(e.Value, l.root.prev)
}
}
// PushFrontList inserts a copy of an other list at the front of list l.
// The lists l and other may be the same. They must not be nil.
func (l *PacketIntervalList) PushFrontList(other *PacketIntervalList) {
l.lazyInit()
for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() {
l.insertValue(e.Value, &l.root)
}
}

View File

@@ -0,0 +1,29 @@
package utils
import (
"crypto/rand"
"encoding/binary"
)
// Rand is a wrapper around crypto/rand that adds some convenience functions known from math/rand.
type Rand struct {
buf [4]byte
}
func (r *Rand) Int31() int32 {
rand.Read(r.buf[:])
return int32(binary.BigEndian.Uint32(r.buf[:]) & ^uint32(1<<31))
}
// copied from the standard library math/rand implementation of Int63n
func (r *Rand) Int31n(n int32) int32 {
if n&(n-1) == 0 { // n is power of two, can mask
return r.Int31() & (n - 1)
}
max := int32((1 << 31) - 1 - (1<<31)%uint32(n))
v := r.Int31()
for v > max {
v = r.Int31()
}
return v % n
}

View File

@@ -0,0 +1,127 @@
package utils
import (
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
const (
rttAlpha = 0.125
oneMinusAlpha = 1 - rttAlpha
rttBeta = 0.25
oneMinusBeta = 1 - rttBeta
// The default RTT used before an RTT sample is taken.
defaultInitialRTT = 100 * time.Millisecond
)
// RTTStats provides round-trip statistics
type RTTStats struct {
hasMeasurement bool
minRTT time.Duration
latestRTT time.Duration
smoothedRTT time.Duration
meanDeviation time.Duration
maxAckDelay time.Duration
}
// NewRTTStats makes a properly initialized RTTStats object
func NewRTTStats() *RTTStats {
return &RTTStats{}
}
// MinRTT Returns the minRTT for the entire connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) MinRTT() time.Duration { return r.minRTT }
// LatestRTT returns the most recent rtt measurement.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) LatestRTT() time.Duration { return r.latestRTT }
// SmoothedRTT returns the smoothed RTT for the connection.
// May return Zero if no valid updates have occurred.
func (r *RTTStats) SmoothedRTT() time.Duration { return r.smoothedRTT }
// MeanDeviation gets the mean deviation
func (r *RTTStats) MeanDeviation() time.Duration { return r.meanDeviation }
// MaxAckDelay gets the max_ack_delay advertised by the peer
func (r *RTTStats) MaxAckDelay() time.Duration { return r.maxAckDelay }
// PTO gets the probe timeout duration.
func (r *RTTStats) PTO(includeMaxAckDelay bool) time.Duration {
if r.SmoothedRTT() == 0 {
return 2 * defaultInitialRTT
}
pto := r.SmoothedRTT() + MaxDuration(4*r.MeanDeviation(), protocol.TimerGranularity)
if includeMaxAckDelay {
pto += r.MaxAckDelay()
}
return pto
}
// UpdateRTT updates the RTT based on a new sample.
func (r *RTTStats) UpdateRTT(sendDelta, ackDelay time.Duration, now time.Time) {
if sendDelta == InfDuration || sendDelta <= 0 {
return
}
// Update r.minRTT first. r.minRTT does not use an rttSample corrected for
// ackDelay but the raw observed sendDelta, since poor clock granularity at
// the client may cause a high ackDelay to result in underestimation of the
// r.minRTT.
if r.minRTT == 0 || r.minRTT > sendDelta {
r.minRTT = sendDelta
}
// Correct for ackDelay if information received from the peer results in a
// an RTT sample at least as large as minRTT. Otherwise, only use the
// sendDelta.
sample := sendDelta
if sample-r.minRTT >= ackDelay {
sample -= ackDelay
}
r.latestRTT = sample
// First time call.
if !r.hasMeasurement {
r.hasMeasurement = true
r.smoothedRTT = sample
r.meanDeviation = sample / 2
} else {
r.meanDeviation = time.Duration(oneMinusBeta*float32(r.meanDeviation/time.Microsecond)+rttBeta*float32(AbsDuration(r.smoothedRTT-sample)/time.Microsecond)) * time.Microsecond
r.smoothedRTT = time.Duration((float32(r.smoothedRTT/time.Microsecond)*oneMinusAlpha)+(float32(sample/time.Microsecond)*rttAlpha)) * time.Microsecond
}
}
// SetMaxAckDelay sets the max_ack_delay
func (r *RTTStats) SetMaxAckDelay(mad time.Duration) {
r.maxAckDelay = mad
}
// SetInitialRTT sets the initial RTT.
// It is used during the 0-RTT handshake when restoring the RTT stats from the session state.
func (r *RTTStats) SetInitialRTT(t time.Duration) {
if r.hasMeasurement {
panic("initial RTT set after first measurement")
}
r.smoothedRTT = t
r.latestRTT = t
}
// OnConnectionMigration is called when connection migrates and rtt measurement needs to be reset.
func (r *RTTStats) OnConnectionMigration() {
r.latestRTT = 0
r.minRTT = 0
r.smoothedRTT = 0
r.meanDeviation = 0
}
// ExpireSmoothedMetrics causes the smoothed_rtt to be increased to the latest_rtt if the latest_rtt
// is larger. The mean deviation is increased to the most recent deviation if
// it's larger.
func (r *RTTStats) ExpireSmoothedMetrics() {
r.meanDeviation = MaxDuration(r.meanDeviation, AbsDuration(r.smoothedRTT-r.latestRTT))
r.smoothedRTT = MaxDuration(r.smoothedRTT, r.latestRTT)
}

View File

@@ -0,0 +1,9 @@
package utils
import "github.com/lucas-clemente/quic-go/internal/protocol"
// ByteInterval is an interval from one ByteCount to the other
type ByteInterval struct {
Start protocol.ByteCount
End protocol.ByteCount
}

View File

@@ -0,0 +1,53 @@
package utils
import (
"math"
"time"
)
// A Timer wrapper that behaves correctly when resetting
type Timer struct {
t *time.Timer
read bool
deadline time.Time
}
// NewTimer creates a new timer that is not set
func NewTimer() *Timer {
return &Timer{t: time.NewTimer(time.Duration(math.MaxInt64))}
}
// Chan returns the channel of the wrapped timer
func (t *Timer) Chan() <-chan time.Time {
return t.t.C
}
// Reset the timer, no matter whether the value was read or not
func (t *Timer) Reset(deadline time.Time) {
if deadline.Equal(t.deadline) && !t.read {
// No need to reset the timer
return
}
// We need to drain the timer if the value from its channel was not read yet.
// See https://groups.google.com/forum/#!topic/golang-dev/c9UUfASVPoU
if !t.t.Stop() && !t.read {
<-t.t.C
}
if !deadline.IsZero() {
t.t.Reset(time.Until(deadline))
}
t.read = false
t.deadline = deadline
}
// SetRead should be called after the value from the chan was read
func (t *Timer) SetRead() {
t.read = true
}
// Stop stops the timer
func (t *Timer) Stop() {
t.t.Stop()
}

View File

@@ -0,0 +1,251 @@
package wire
import (
"bytes"
"errors"
"sort"
"time"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
)
var errInvalidAckRanges = errors.New("AckFrame: ACK frame contains invalid ACK ranges")
// An AckFrame is an ACK frame
type AckFrame struct {
AckRanges []AckRange // has to be ordered. The highest ACK range goes first, the lowest ACK range goes last
DelayTime time.Duration
ECT0, ECT1, ECNCE uint64
}
// parseAckFrame reads an ACK frame
func parseAckFrame(r *bytes.Reader, ackDelayExponent uint8, _ protocol.VersionNumber) (*AckFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
ecn := typeByte&0x1 > 0
frame := &AckFrame{}
la, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
largestAcked := protocol.PacketNumber(la)
delay, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
delayTime := time.Duration(delay*1<<ackDelayExponent) * time.Microsecond
if delayTime < 0 {
// If the delay time overflows, set it to the maximum encodable value.
delayTime = utils.InfDuration
}
frame.DelayTime = delayTime
numBlocks, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
// read the first ACK range
ab, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largestAcked {
return nil, errors.New("invalid first ACK range")
}
smallest := largestAcked - ackBlock
// read all the other ACK ranges
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largestAcked})
for i := uint64(0); i < numBlocks; i++ {
g, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
gap := protocol.PacketNumber(g)
if smallest < gap+2 {
return nil, errInvalidAckRanges
}
largest := smallest - gap - 2
ab, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
ackBlock := protocol.PacketNumber(ab)
if ackBlock > largest {
return nil, errInvalidAckRanges
}
smallest = largest - ackBlock
frame.AckRanges = append(frame.AckRanges, AckRange{Smallest: smallest, Largest: largest})
}
if !frame.validateAckRanges() {
return nil, errInvalidAckRanges
}
// parse (and skip) the ECN section
if ecn {
for i := 0; i < 3; i++ {
if _, err := quicvarint.Read(r); err != nil {
return nil, err
}
}
}
return frame, nil
}
// Write writes an ACK frame.
func (f *AckFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
if hasECN {
b.WriteByte(0x3)
} else {
b.WriteByte(0x2)
}
quicvarint.Write(b, uint64(f.LargestAcked()))
quicvarint.Write(b, encodeAckDelay(f.DelayTime))
numRanges := f.numEncodableAckRanges()
quicvarint.Write(b, uint64(numRanges-1))
// write the first range
_, firstRange := f.encodeAckRange(0)
quicvarint.Write(b, firstRange)
// write all the other range
for i := 1; i < numRanges; i++ {
gap, len := f.encodeAckRange(i)
quicvarint.Write(b, gap)
quicvarint.Write(b, len)
}
if hasECN {
quicvarint.Write(b, f.ECT0)
quicvarint.Write(b, f.ECT1)
quicvarint.Write(b, f.ECNCE)
}
return nil
}
// Length of a written frame
func (f *AckFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
largestAcked := f.AckRanges[0].Largest
numRanges := f.numEncodableAckRanges()
length := 1 + quicvarint.Len(uint64(largestAcked)) + quicvarint.Len(encodeAckDelay(f.DelayTime))
length += quicvarint.Len(uint64(numRanges - 1))
lowestInFirstRange := f.AckRanges[0].Smallest
length += quicvarint.Len(uint64(largestAcked - lowestInFirstRange))
for i := 1; i < numRanges; i++ {
gap, len := f.encodeAckRange(i)
length += quicvarint.Len(gap)
length += quicvarint.Len(len)
}
if f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0 {
length += quicvarint.Len(f.ECT0)
length += quicvarint.Len(f.ECT1)
length += quicvarint.Len(f.ECNCE)
}
return length
}
// gets the number of ACK ranges that can be encoded
// such that the resulting frame is smaller than the maximum ACK frame size
func (f *AckFrame) numEncodableAckRanges() int {
length := 1 + quicvarint.Len(uint64(f.LargestAcked())) + quicvarint.Len(encodeAckDelay(f.DelayTime))
length += 2 // assume that the number of ranges will consume 2 bytes
for i := 1; i < len(f.AckRanges); i++ {
gap, len := f.encodeAckRange(i)
rangeLen := quicvarint.Len(gap) + quicvarint.Len(len)
if length+rangeLen > protocol.MaxAckFrameSize {
// Writing range i would exceed the MaxAckFrameSize.
// So encode one range less than that.
return i - 1
}
length += rangeLen
}
return len(f.AckRanges)
}
func (f *AckFrame) encodeAckRange(i int) (uint64 /* gap */, uint64 /* length */) {
if i == 0 {
return 0, uint64(f.AckRanges[0].Largest - f.AckRanges[0].Smallest)
}
return uint64(f.AckRanges[i-1].Smallest - f.AckRanges[i].Largest - 2),
uint64(f.AckRanges[i].Largest - f.AckRanges[i].Smallest)
}
// HasMissingRanges returns if this frame reports any missing packets
func (f *AckFrame) HasMissingRanges() bool {
return len(f.AckRanges) > 1
}
func (f *AckFrame) validateAckRanges() bool {
if len(f.AckRanges) == 0 {
return false
}
// check the validity of every single ACK range
for _, ackRange := range f.AckRanges {
if ackRange.Smallest > ackRange.Largest {
return false
}
}
// check the consistency for ACK with multiple NACK ranges
for i, ackRange := range f.AckRanges {
if i == 0 {
continue
}
lastAckRange := f.AckRanges[i-1]
if lastAckRange.Smallest <= ackRange.Smallest {
return false
}
if lastAckRange.Smallest <= ackRange.Largest+1 {
return false
}
}
return true
}
// LargestAcked is the largest acked packet number
func (f *AckFrame) LargestAcked() protocol.PacketNumber {
return f.AckRanges[0].Largest
}
// LowestAcked is the lowest acked packet number
func (f *AckFrame) LowestAcked() protocol.PacketNumber {
return f.AckRanges[len(f.AckRanges)-1].Smallest
}
// AcksPacket determines if this ACK frame acks a certain packet number
func (f *AckFrame) AcksPacket(p protocol.PacketNumber) bool {
if p < f.LowestAcked() || p > f.LargestAcked() {
return false
}
i := sort.Search(len(f.AckRanges), func(i int) bool {
return p >= f.AckRanges[i].Smallest
})
// i will always be < len(f.AckRanges), since we checked above that p is not bigger than the largest acked
return p <= f.AckRanges[i].Largest
}
func encodeAckDelay(delay time.Duration) uint64 {
return uint64(delay.Nanoseconds() / (1000 * (1 << protocol.AckDelayExponent)))
}

View File

@@ -0,0 +1,14 @@
package wire
import "github.com/lucas-clemente/quic-go/internal/protocol"
// AckRange is an ACK range
type AckRange struct {
Smallest protocol.PacketNumber
Largest protocol.PacketNumber
}
// Len returns the number of packets contained in this ACK range
func (r AckRange) Len() protocol.PacketNumber {
return r.Largest - r.Smallest + 1
}

View File

@@ -0,0 +1,83 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A ConnectionCloseFrame is a CONNECTION_CLOSE frame
type ConnectionCloseFrame struct {
IsApplicationError bool
ErrorCode uint64
FrameType uint64
ReasonPhrase string
}
func parseConnectionCloseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ConnectionCloseFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &ConnectionCloseFrame{IsApplicationError: typeByte == 0x1d}
ec, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
f.ErrorCode = ec
// read the Frame Type, if this is not an application error
if !f.IsApplicationError {
ft, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
f.FrameType = ft
}
var reasonPhraseLen uint64
reasonPhraseLen, err = quicvarint.Read(r)
if err != nil {
return nil, err
}
// shortcut to prevent the unnecessary allocation of dataLen bytes
// if the dataLen is larger than the remaining length of the packet
// reading the whole reason phrase would result in EOF when attempting to READ
if int(reasonPhraseLen) > r.Len() {
return nil, io.EOF
}
reasonPhrase := make([]byte, reasonPhraseLen)
if _, err := io.ReadFull(r, reasonPhrase); err != nil {
// this should never happen, since we already checked the reasonPhraseLen earlier
return nil, err
}
f.ReasonPhrase = string(reasonPhrase)
return f, nil
}
// Length of a written frame
func (f *ConnectionCloseFrame) Length(protocol.VersionNumber) protocol.ByteCount {
length := 1 + quicvarint.Len(f.ErrorCode) + quicvarint.Len(uint64(len(f.ReasonPhrase))) + protocol.ByteCount(len(f.ReasonPhrase))
if !f.IsApplicationError {
length += quicvarint.Len(f.FrameType) // for the frame type
}
return length
}
func (f *ConnectionCloseFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
if f.IsApplicationError {
b.WriteByte(0x1d)
} else {
b.WriteByte(0x1c)
}
quicvarint.Write(b, f.ErrorCode)
if !f.IsApplicationError {
quicvarint.Write(b, f.FrameType)
}
quicvarint.Write(b, uint64(len(f.ReasonPhrase)))
b.WriteString(f.ReasonPhrase)
return nil
}

View File

@@ -0,0 +1,102 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A CryptoFrame is a CRYPTO frame
type CryptoFrame struct {
Offset protocol.ByteCount
Data []byte
}
func parseCryptoFrame(r *bytes.Reader, _ protocol.VersionNumber) (*CryptoFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
frame := &CryptoFrame{}
offset, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
frame.Offset = protocol.ByteCount(offset)
dataLen, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if dataLen > uint64(r.Len()) {
return nil, io.EOF
}
if dataLen != 0 {
frame.Data = make([]byte, dataLen)
if _, err := io.ReadFull(r, frame.Data); err != nil {
// this should never happen, since we already checked the dataLen earlier
return nil, err
}
}
return frame, nil
}
func (f *CryptoFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x6)
quicvarint.Write(b, uint64(f.Offset))
quicvarint.Write(b, uint64(len(f.Data)))
b.Write(f.Data)
return nil
}
// Length of a written frame
func (f *CryptoFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.Offset)) + quicvarint.Len(uint64(len(f.Data))) + protocol.ByteCount(len(f.Data))
}
// MaxDataLen returns the maximum data length
func (f *CryptoFrame) MaxDataLen(maxSize protocol.ByteCount) protocol.ByteCount {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen := 1 + quicvarint.Len(uint64(f.Offset)) + 1
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes.
// It returns if the frame was actually split.
// The frame might not be split if:
// * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *CryptoFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*CryptoFrame, bool /* was splitting required */) {
if f.Length(version) <= maxSize {
return nil, false
}
n := f.MaxDataLen(maxSize)
if n == 0 {
return nil, true
}
newLen := protocol.ByteCount(len(f.Data)) - n
new := &CryptoFrame{}
new.Offset = f.Offset
new.Data = make([]byte, newLen)
// swap the data slices
new.Data, f.Data = f.Data, new.Data
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n
return new, true
}

View File

@@ -0,0 +1,38 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A DataBlockedFrame is a DATA_BLOCKED frame
type DataBlockedFrame struct {
MaximumData protocol.ByteCount
}
func parseDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DataBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
offset, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
return &DataBlockedFrame{
MaximumData: protocol.ByteCount(offset),
}, nil
}
func (f *DataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
typeByte := uint8(0x14)
b.WriteByte(typeByte)
quicvarint.Write(b, uint64(f.MaximumData))
return nil
}
// Length of a written frame
func (f *DataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
}

View File

@@ -0,0 +1,85 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A DatagramFrame is a DATAGRAM frame
type DatagramFrame struct {
DataLenPresent bool
Data []byte
}
func parseDatagramFrame(r *bytes.Reader, _ protocol.VersionNumber) (*DatagramFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &DatagramFrame{}
f.DataLenPresent = typeByte&0x1 > 0
var length uint64
if f.DataLenPresent {
var err error
len, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if len > uint64(r.Len()) {
return nil, io.EOF
}
length = len
} else {
length = uint64(r.Len())
}
f.Data = make([]byte, length)
if _, err := io.ReadFull(r, f.Data); err != nil {
return nil, err
}
return f, nil
}
func (f *DatagramFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := uint8(0x30)
if f.DataLenPresent {
typeByte ^= 0x1
}
b.WriteByte(typeByte)
if f.DataLenPresent {
quicvarint.Write(b, uint64(len(f.Data)))
}
b.Write(f.Data)
return nil
}
// MaxDataLen returns the maximum data length
func (f *DatagramFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
headerLen := protocol.ByteCount(1)
if f.DataLenPresent {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen++
}
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// Length of a written frame
func (f *DatagramFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
length := 1 + protocol.ByteCount(len(f.Data))
if f.DataLenPresent {
length += quicvarint.Len(uint64(len(f.Data)))
}
return length
}

View File

@@ -0,0 +1,235 @@
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// ErrInvalidReservedBits is returned when the reserved bits are incorrect.
// When this error is returned, parsing continues, and an ExtendedHeader is returned.
// This is necessary because we need to decrypt the packet in that case,
// in order to avoid a timing side-channel.
var ErrInvalidReservedBits = errors.New("invalid reserved bits")
// ExtendedHeader is the header of a QUIC packet.
type ExtendedHeader struct {
Header
typeByte byte
KeyPhase protocol.KeyPhaseBit
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
parsedLen protocol.ByteCount
}
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
startLen := b.Len()
// read the (now unencrypted) first byte
var err error
h.typeByte, err = b.ReadByte()
if err != nil {
return false, err
}
if _, err := b.Seek(int64(h.Header.ParsedLen())-1, io.SeekCurrent); err != nil {
return false, err
}
var reservedBitsValid bool
if h.IsLongHeader {
reservedBitsValid, err = h.parseLongHeader(b, v)
} else {
reservedBitsValid, err = h.parseShortHeader(b, v)
}
if err != nil {
return false, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return reservedBitsValid, err
}
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0xc != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
h.KeyPhase = protocol.KeyPhaseZero
if h.typeByte&0x4 > 0 {
h.KeyPhase = protocol.KeyPhaseOne
}
if err := h.readPacketNumber(b); err != nil {
return false, err
}
if h.typeByte&0x18 != 0 {
return false, nil
}
return true, nil
}
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
n, err := b.ReadByte()
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen2:
n, err := utils.BigEndian.ReadUint16(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen3:
n, err := utils.BigEndian.ReadUint24(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
case protocol.PacketNumberLen4:
n, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.PacketNumber = protocol.PacketNumber(n)
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
return nil
}
// Write writes the Header.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
}
if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
}
if h.IsLongHeader {
return h.writeLongHeader(b, ver)
}
return h.writeShortHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
var packetType uint8
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0x0
case protocol.PacketType0RTT:
packetType = 0x1
case protocol.PacketTypeHandshake:
packetType = 0x2
case protocol.PacketTypeRetry:
packetType = 0x3
}
firstByte := 0xc0 | packetType<<4
if h.Type != protocol.PacketTypeRetry {
// Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
b.WriteByte(uint8(h.DestConnectionID.Len()))
b.Write(h.DestConnectionID.Bytes())
b.WriteByte(uint8(h.SrcConnectionID.Len()))
b.Write(h.SrcConnectionID.Bytes())
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.Token)
return nil
case protocol.PacketTypeInitial:
quicvarint.Write(b, uint64(len(h.Token)))
b.Write(h.Token)
}
quicvarint.WriteWithLen(b, uint64(h.Length), 2)
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
typeByte := 0x40 | uint8(h.PacketNumberLen-1)
if h.KeyPhase == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
return nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// GetLength determines the length of the Header.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
if h.IsLongHeader {
length := 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) + 1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
if h.Type == protocol.PacketTypeInitial {
length += quicvarint.Len(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
}
return length
}
length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
length += protocol.ByteCount(h.PacketNumberLen)
return length
}
// Log logs the Header
func (h *ExtendedHeader) Log(logger utils.Logger) {
if h.IsLongHeader {
var token string
if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
if len(h.Token) == 0 {
token = "Token: (empty), "
} else {
token = fmt.Sprintf("Token: %#x, ", h.Token)
}
if h.Type == protocol.PacketTypeRetry {
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
return
}
}
logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
} else {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
}
}

View File

@@ -0,0 +1,143 @@
package wire
import (
"bytes"
"errors"
"fmt"
"reflect"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
)
type frameParser struct {
ackDelayExponent uint8
supportsDatagrams bool
version protocol.VersionNumber
}
// NewFrameParser creates a new frame parser.
func NewFrameParser(supportsDatagrams bool, v protocol.VersionNumber) FrameParser {
return &frameParser{
supportsDatagrams: supportsDatagrams,
version: v,
}
}
// ParseNext parses the next frame.
// It skips PADDING frames.
func (p *frameParser) ParseNext(r *bytes.Reader, encLevel protocol.EncryptionLevel) (Frame, error) {
for r.Len() != 0 {
typeByte, _ := r.ReadByte()
if typeByte == 0x0 { // PADDING frame
continue
}
r.UnreadByte()
f, err := p.parseFrame(r, typeByte, encLevel)
if err != nil {
return nil, &qerr.TransportError{
FrameType: uint64(typeByte),
ErrorCode: qerr.FrameEncodingError,
ErrorMessage: err.Error(),
}
}
return f, nil
}
return nil, nil
}
func (p *frameParser) parseFrame(r *bytes.Reader, typeByte byte, encLevel protocol.EncryptionLevel) (Frame, error) {
var frame Frame
var err error
if typeByte&0xf8 == 0x8 {
frame, err = parseStreamFrame(r, p.version)
} else {
switch typeByte {
case 0x1:
frame, err = parsePingFrame(r, p.version)
case 0x2, 0x3:
ackDelayExponent := p.ackDelayExponent
if encLevel != protocol.Encryption1RTT {
ackDelayExponent = protocol.DefaultAckDelayExponent
}
frame, err = parseAckFrame(r, ackDelayExponent, p.version)
case 0x4:
frame, err = parseResetStreamFrame(r, p.version)
case 0x5:
frame, err = parseStopSendingFrame(r, p.version)
case 0x6:
frame, err = parseCryptoFrame(r, p.version)
case 0x7:
frame, err = parseNewTokenFrame(r, p.version)
case 0x10:
frame, err = parseMaxDataFrame(r, p.version)
case 0x11:
frame, err = parseMaxStreamDataFrame(r, p.version)
case 0x12, 0x13:
frame, err = parseMaxStreamsFrame(r, p.version)
case 0x14:
frame, err = parseDataBlockedFrame(r, p.version)
case 0x15:
frame, err = parseStreamDataBlockedFrame(r, p.version)
case 0x16, 0x17:
frame, err = parseStreamsBlockedFrame(r, p.version)
case 0x18:
frame, err = parseNewConnectionIDFrame(r, p.version)
case 0x19:
frame, err = parseRetireConnectionIDFrame(r, p.version)
case 0x1a:
frame, err = parsePathChallengeFrame(r, p.version)
case 0x1b:
frame, err = parsePathResponseFrame(r, p.version)
case 0x1c, 0x1d:
frame, err = parseConnectionCloseFrame(r, p.version)
case 0x1e:
frame, err = parseHandshakeDoneFrame(r, p.version)
case 0x30, 0x31:
if p.supportsDatagrams {
frame, err = parseDatagramFrame(r, p.version)
break
}
fallthrough
default:
err = errors.New("unknown frame type")
}
}
if err != nil {
return nil, err
}
if !p.isAllowedAtEncLevel(frame, encLevel) {
return nil, fmt.Errorf("%s not allowed at encryption level %s", reflect.TypeOf(frame).Elem().Name(), encLevel)
}
return frame, nil
}
func (p *frameParser) isAllowedAtEncLevel(f Frame, encLevel protocol.EncryptionLevel) bool {
switch encLevel {
case protocol.EncryptionInitial, protocol.EncryptionHandshake:
switch f.(type) {
case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *PingFrame:
return true
default:
return false
}
case protocol.Encryption0RTT:
switch f.(type) {
case *CryptoFrame, *AckFrame, *ConnectionCloseFrame, *NewTokenFrame, *PathResponseFrame, *RetireConnectionIDFrame:
return false
default:
return true
}
case protocol.Encryption1RTT:
return true
default:
panic("unknown encryption level")
}
}
func (p *frameParser) SetAckDelayExponent(exp uint8) {
p.ackDelayExponent = exp
}

View File

@@ -0,0 +1,28 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A HandshakeDoneFrame is a HANDSHAKE_DONE frame
type HandshakeDoneFrame struct{}
// ParseHandshakeDoneFrame parses a HandshakeDone frame
func parseHandshakeDoneFrame(r *bytes.Reader, _ protocol.VersionNumber) (*HandshakeDoneFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
return &HandshakeDoneFrame{}, nil
}
func (f *HandshakeDoneFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x1e)
return nil
}
// Length of a written frame
func (f *HandshakeDoneFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1
}

View File

@@ -0,0 +1,257 @@
package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// ParseConnectionID parses the destination connection ID of a packet.
// It uses the data slice for the connection ID.
// That means that the connection ID must not be used after the packet buffer is released.
func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
if len(data) == 0 {
return nil, io.EOF
}
isLongHeader := data[0]&0x80 > 0
if !isLongHeader {
if len(data) < shortHeaderConnIDLen+1 {
return nil, io.EOF
}
return protocol.ConnectionID(data[1 : 1+shortHeaderConnIDLen]), nil
}
if len(data) < 6 {
return nil, io.EOF
}
destConnIDLen := int(data[5])
if len(data) < 6+destConnIDLen {
return nil, io.EOF
}
return protocol.ConnectionID(data[6 : 6+destConnIDLen]), nil
}
// IsVersionNegotiationPacket says if this is a version negotiation packet
func IsVersionNegotiationPacket(b []byte) bool {
if len(b) < 5 {
return false
}
return b[0]&0x80 > 0 && b[1] == 0 && b[2] == 0 && b[3] == 0 && b[4] == 0
}
// Is0RTTPacket says if this is a 0-RTT packet.
// A packet sent with a version we don't understand can never be a 0-RTT packet.
func Is0RTTPacket(b []byte) bool {
if len(b) < 5 {
return false
}
if b[0]&0x80 == 0 {
return false
}
if !protocol.IsSupportedVersion(protocol.SupportedVersions, protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))) {
return false
}
return b[0]&0x30>>4 == 0x1
}
var ErrUnsupportedVersion = errors.New("unsupported version")
// The Header is the version independent part of the header
type Header struct {
IsLongHeader bool
typeByte byte
Type protocol.PacketType
Version protocol.VersionNumber
SrcConnectionID protocol.ConnectionID
DestConnectionID protocol.ConnectionID
Length protocol.ByteCount
Token []byte
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses a packet.
// If the packet has a long header, the packet is cut according to the length field.
// If we understand the version, the packet is header up unto the packet number.
// Otherwise, only the invariant part of the header is parsed.
func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) {
hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen)
if err != nil {
if err == ErrUnsupportedVersion {
return hdr, nil, nil, ErrUnsupportedVersion
}
return nil, nil, nil, err
}
var rest []byte
if hdr.IsLongHeader {
if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
}
packetLen := int(hdr.ParsedLen() + hdr.Length)
rest = data[packetLen:]
data = data[:packetLen]
}
return hdr, data, rest, nil
}
// ParseHeader parses the header.
// For short header packets: up to the packet number.
// For long header packets:
// * if we understand the version: up to the packet number
// * if not, only the invariant part of the header
func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
startLen := b.Len()
h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
if err != nil {
return h, err
}
h.parsedLen = protocol.ByteCount(startLen - b.Len())
return h, err
}
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
typeByte, err := b.ReadByte()
if err != nil {
return nil, err
}
h := &Header{
typeByte: typeByte,
IsLongHeader: typeByte&0x80 > 0,
}
if !h.IsLongHeader {
if h.typeByte&0x40 == 0 {
return nil, errors.New("not a QUIC packet")
}
if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
return nil, err
}
return h, nil
}
return h, h.parseLongHeader(b)
}
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
var err error
h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
return err
}
func (h *Header) parseLongHeader(b *bytes.Reader) error {
v, err := utils.BigEndian.ReadUint32(b)
if err != nil {
return err
}
h.Version = protocol.VersionNumber(v)
if h.Version != 0 && h.typeByte&0x40 == 0 {
return errors.New("not a QUIC packet")
}
destConnIDLen, err := b.ReadByte()
if err != nil {
return err
}
h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destConnIDLen))
if err != nil {
return err
}
srcConnIDLen, err := b.ReadByte()
if err != nil {
return err
}
h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcConnIDLen))
if err != nil {
return err
}
if h.Version == 0 { // version negotiation packet
return nil
}
// If we don't understand the version, we have no idea how to interpret the rest of the bytes
if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
return ErrUnsupportedVersion
}
switch (h.typeByte & 0x30) >> 4 {
case 0x0:
h.Type = protocol.PacketTypeInitial
case 0x1:
h.Type = protocol.PacketType0RTT
case 0x2:
h.Type = protocol.PacketTypeHandshake
case 0x3:
h.Type = protocol.PacketTypeRetry
}
if h.Type == protocol.PacketTypeRetry {
tokenLen := b.Len() - 16
if tokenLen <= 0 {
return io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
_, err := b.Seek(16, io.SeekCurrent)
return err
}
if h.Type == protocol.PacketTypeInitial {
tokenLen, err := quicvarint.Read(b)
if err != nil {
return err
}
if tokenLen > uint64(b.Len()) {
return io.EOF
}
h.Token = make([]byte, tokenLen)
if _, err := io.ReadFull(b, h.Token); err != nil {
return err
}
}
pl, err := quicvarint.Read(b)
if err != nil {
return err
}
h.Length = protocol.ByteCount(pl)
return nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header
func (h *Header) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
extHdr := h.toExtendedHeader()
reservedBitsValid, err := extHdr.parse(b, ver)
if err != nil {
return nil, err
}
if !reservedBitsValid {
return extHdr, ErrInvalidReservedBits
}
return extHdr, nil
}
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
}
// PacketType is the type of the packet, for logging purposes
func (h *Header) PacketType() string {
if h.IsLongHeader {
return h.Type.String()
}
return "1-RTT"
}

View File

@@ -0,0 +1,19 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A Frame in QUIC
type Frame interface {
Write(b *bytes.Buffer, version protocol.VersionNumber) error
Length(version protocol.VersionNumber) protocol.ByteCount
}
// A FrameParser parses QUIC frames, one by one.
type FrameParser interface {
ParseNext(*bytes.Reader, protocol.EncryptionLevel) (Frame, error)
SetAckDelayExponent(uint8)
}

View File

@@ -0,0 +1,72 @@
package wire
import (
"fmt"
"strings"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
)
// LogFrame logs a frame, either sent or received
func LogFrame(logger utils.Logger, frame Frame, sent bool) {
if !logger.Debug() {
return
}
dir := "<-"
if sent {
dir = "->"
}
switch f := frame.(type) {
case *CryptoFrame:
dataLen := protocol.ByteCount(len(f.Data))
logger.Debugf("\t%s &wire.CryptoFrame{Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.Offset, dataLen, f.Offset+dataLen)
case *StreamFrame:
logger.Debugf("\t%s &wire.StreamFrame{StreamID: %d, Fin: %t, Offset: %d, Data length: %d, Offset + Data length: %d}", dir, f.StreamID, f.Fin, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *ResetStreamFrame:
logger.Debugf("\t%s &wire.ResetStreamFrame{StreamID: %d, ErrorCode: %#x, FinalSize: %d}", dir, f.StreamID, f.ErrorCode, f.FinalSize)
case *AckFrame:
hasECN := f.ECT0 > 0 || f.ECT1 > 0 || f.ECNCE > 0
var ecn string
if hasECN {
ecn = fmt.Sprintf(", ECT0: %d, ECT1: %d, CE: %d", f.ECT0, f.ECT1, f.ECNCE)
}
if len(f.AckRanges) > 1 {
ackRanges := make([]string, len(f.AckRanges))
for i, r := range f.AckRanges {
ackRanges[i] = fmt.Sprintf("{Largest: %d, Smallest: %d}", r.Largest, r.Smallest)
}
logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, AckRanges: {%s}, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), strings.Join(ackRanges, ", "), f.DelayTime.String(), ecn)
} else {
logger.Debugf("\t%s &wire.AckFrame{LargestAcked: %d, LowestAcked: %d, DelayTime: %s%s}", dir, f.LargestAcked(), f.LowestAcked(), f.DelayTime.String(), ecn)
}
case *MaxDataFrame:
logger.Debugf("\t%s &wire.MaxDataFrame{MaximumData: %d}", dir, f.MaximumData)
case *MaxStreamDataFrame:
logger.Debugf("\t%s &wire.MaxStreamDataFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData)
case *DataBlockedFrame:
logger.Debugf("\t%s &wire.DataBlockedFrame{MaximumData: %d}", dir, f.MaximumData)
case *StreamDataBlockedFrame:
logger.Debugf("\t%s &wire.StreamDataBlockedFrame{StreamID: %d, MaximumStreamData: %d}", dir, f.StreamID, f.MaximumStreamData)
case *MaxStreamsFrame:
switch f.Type {
case protocol.StreamTypeUni:
logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: uni, MaxStreamNum: %d}", dir, f.MaxStreamNum)
case protocol.StreamTypeBidi:
logger.Debugf("\t%s &wire.MaxStreamsFrame{Type: bidi, MaxStreamNum: %d}", dir, f.MaxStreamNum)
}
case *StreamsBlockedFrame:
switch f.Type {
case protocol.StreamTypeUni:
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: uni, MaxStreams: %d}", dir, f.StreamLimit)
case protocol.StreamTypeBidi:
logger.Debugf("\t%s &wire.StreamsBlockedFrame{Type: bidi, MaxStreams: %d}", dir, f.StreamLimit)
}
case *NewConnectionIDFrame:
logger.Debugf("\t%s &wire.NewConnectionIDFrame{SequenceNumber: %d, ConnectionID: %s, StatelessResetToken: %#x}", dir, f.SequenceNumber, f.ConnectionID, f.StatelessResetToken)
case *NewTokenFrame:
logger.Debugf("\t%s &wire.NewTokenFrame{Token: %#x}", dir, f.Token)
default:
logger.Debugf("\t%s %#v", dir, frame)
}
}

View File

@@ -0,0 +1,40 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A MaxDataFrame carries flow control information for the connection
type MaxDataFrame struct {
MaximumData protocol.ByteCount
}
// parseMaxDataFrame parses a MAX_DATA frame
func parseMaxDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxDataFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
frame := &MaxDataFrame{}
byteOffset, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
frame.MaximumData = protocol.ByteCount(byteOffset)
return frame, nil
}
// Write writes a MAX_STREAM_DATA frame
func (f *MaxDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x10)
quicvarint.Write(b, uint64(f.MaximumData))
return nil
}
// Length of a written frame
func (f *MaxDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaximumData))
}

View File

@@ -0,0 +1,46 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A MaxStreamDataFrame is a MAX_STREAM_DATA frame
type MaxStreamDataFrame struct {
StreamID protocol.StreamID
MaximumStreamData protocol.ByteCount
}
func parseMaxStreamDataFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamDataFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
sid, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
offset, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
return &MaxStreamDataFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}
func (f *MaxStreamDataFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x11)
quicvarint.Write(b, uint64(f.StreamID))
quicvarint.Write(b, uint64(f.MaximumStreamData))
return nil
}
// Length of a written frame
func (f *MaxStreamDataFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
}

View File

@@ -0,0 +1,55 @@
package wire
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A MaxStreamsFrame is a MAX_STREAMS frame
type MaxStreamsFrame struct {
Type protocol.StreamType
MaxStreamNum protocol.StreamNum
}
func parseMaxStreamsFrame(r *bytes.Reader, _ protocol.VersionNumber) (*MaxStreamsFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &MaxStreamsFrame{}
switch typeByte {
case 0x12:
f.Type = protocol.StreamTypeBidi
case 0x13:
f.Type = protocol.StreamTypeUni
}
streamID, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
f.MaxStreamNum = protocol.StreamNum(streamID)
if f.MaxStreamNum > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.MaxStreamNum)
}
return f, nil
}
func (f *MaxStreamsFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
switch f.Type {
case protocol.StreamTypeBidi:
b.WriteByte(0x12)
case protocol.StreamTypeUni:
b.WriteByte(0x13)
}
quicvarint.Write(b, uint64(f.MaxStreamNum))
return nil
}
// Length of a written frame
func (f *MaxStreamsFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.MaxStreamNum))
}

View File

@@ -0,0 +1,80 @@
package wire
import (
"bytes"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A NewConnectionIDFrame is a NEW_CONNECTION_ID frame
type NewConnectionIDFrame struct {
SequenceNumber uint64
RetirePriorTo uint64
ConnectionID protocol.ConnectionID
StatelessResetToken protocol.StatelessResetToken
}
func parseNewConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewConnectionIDFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
seq, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
ret, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if ret > seq {
//nolint:stylecheck
return nil, fmt.Errorf("Retire Prior To value (%d) larger than Sequence Number (%d)", ret, seq)
}
connIDLen, err := r.ReadByte()
if err != nil {
return nil, err
}
if connIDLen > protocol.MaxConnIDLen {
return nil, fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
connID, err := protocol.ReadConnectionID(r, int(connIDLen))
if err != nil {
return nil, err
}
frame := &NewConnectionIDFrame{
SequenceNumber: seq,
RetirePriorTo: ret,
ConnectionID: connID,
}
if _, err := io.ReadFull(r, frame.StatelessResetToken[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
return frame, nil
}
func (f *NewConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x18)
quicvarint.Write(b, f.SequenceNumber)
quicvarint.Write(b, f.RetirePriorTo)
connIDLen := f.ConnectionID.Len()
if connIDLen > protocol.MaxConnIDLen {
return fmt.Errorf("invalid connection ID length: %d", connIDLen)
}
b.WriteByte(uint8(connIDLen))
b.Write(f.ConnectionID.Bytes())
b.Write(f.StatelessResetToken[:])
return nil
}
// Length of a written frame
func (f *NewConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber) + quicvarint.Len(f.RetirePriorTo) + 1 /* connection ID length */ + protocol.ByteCount(f.ConnectionID.Len()) + 16
}

View File

@@ -0,0 +1,48 @@
package wire
import (
"bytes"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A NewTokenFrame is a NEW_TOKEN frame
type NewTokenFrame struct {
Token []byte
}
func parseNewTokenFrame(r *bytes.Reader, _ protocol.VersionNumber) (*NewTokenFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
tokenLen, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
if uint64(r.Len()) < tokenLen {
return nil, io.EOF
}
if tokenLen == 0 {
return nil, errors.New("token must not be empty")
}
token := make([]byte, int(tokenLen))
if _, err := io.ReadFull(r, token); err != nil {
return nil, err
}
return &NewTokenFrame{Token: token}, nil
}
func (f *NewTokenFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x7)
quicvarint.Write(b, uint64(len(f.Token)))
b.Write(f.Token)
return nil
}
// Length of a written frame
func (f *NewTokenFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(len(f.Token))) + protocol.ByteCount(len(f.Token))
}

View File

@@ -0,0 +1,38 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A PathChallengeFrame is a PATH_CHALLENGE frame
type PathChallengeFrame struct {
Data [8]byte
}
func parsePathChallengeFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathChallengeFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
frame := &PathChallengeFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
return frame, nil
}
func (f *PathChallengeFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x1a)
b.Write(f.Data[:])
return nil
}
// Length of a written frame
func (f *PathChallengeFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + 8
}

View File

@@ -0,0 +1,38 @@
package wire
import (
"bytes"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A PathResponseFrame is a PATH_RESPONSE frame
type PathResponseFrame struct {
Data [8]byte
}
func parsePathResponseFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PathResponseFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
frame := &PathResponseFrame{}
if _, err := io.ReadFull(r, frame.Data[:]); err != nil {
if err == io.ErrUnexpectedEOF {
return nil, io.EOF
}
return nil, err
}
return frame, nil
}
func (f *PathResponseFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x1b)
b.Write(f.Data[:])
return nil
}
// Length of a written frame
func (f *PathResponseFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + 8
}

View File

@@ -0,0 +1,27 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
// A PingFrame is a PING frame
type PingFrame struct{}
func parsePingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*PingFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
return &PingFrame{}, nil
}
func (f *PingFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x1)
return nil
}
// Length of a written frame
func (f *PingFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1
}

View File

@@ -0,0 +1,33 @@
package wire
import (
"sync"
"github.com/lucas-clemente/quic-go/internal/protocol"
)
var pool sync.Pool
func init() {
pool.New = func() interface{} {
return &StreamFrame{
Data: make([]byte, 0, protocol.MaxPacketBufferSize),
fromPool: true,
}
}
}
func GetStreamFrame() *StreamFrame {
f := pool.Get().(*StreamFrame)
return f
}
func putStreamFrame(f *StreamFrame) {
if !f.fromPool {
return
}
if protocol.ByteCount(cap(f.Data)) != protocol.MaxPacketBufferSize {
panic("wire.PutStreamFrame called with packet of wrong size!")
}
pool.Put(f)
}

View File

@@ -0,0 +1,58 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A ResetStreamFrame is a RESET_STREAM frame in QUIC
type ResetStreamFrame struct {
StreamID protocol.StreamID
ErrorCode qerr.StreamErrorCode
FinalSize protocol.ByteCount
}
func parseResetStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*ResetStreamFrame, error) {
if _, err := r.ReadByte(); err != nil { // read the TypeByte
return nil, err
}
var streamID protocol.StreamID
var byteOffset protocol.ByteCount
sid, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
streamID = protocol.StreamID(sid)
errorCode, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
bo, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
byteOffset = protocol.ByteCount(bo)
return &ResetStreamFrame{
StreamID: streamID,
ErrorCode: qerr.StreamErrorCode(errorCode),
FinalSize: byteOffset,
}, nil
}
func (f *ResetStreamFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x4)
quicvarint.Write(b, uint64(f.StreamID))
quicvarint.Write(b, uint64(f.ErrorCode))
quicvarint.Write(b, uint64(f.FinalSize))
return nil
}
// Length of a written frame
func (f *ResetStreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode)) + quicvarint.Len(uint64(f.FinalSize))
}

View File

@@ -0,0 +1,36 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A RetireConnectionIDFrame is a RETIRE_CONNECTION_ID frame
type RetireConnectionIDFrame struct {
SequenceNumber uint64
}
func parseRetireConnectionIDFrame(r *bytes.Reader, _ protocol.VersionNumber) (*RetireConnectionIDFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
seq, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
return &RetireConnectionIDFrame{SequenceNumber: seq}, nil
}
func (f *RetireConnectionIDFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x19)
quicvarint.Write(b, f.SequenceNumber)
return nil
}
// Length of a written frame
func (f *RetireConnectionIDFrame) Length(protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(f.SequenceNumber)
}

View File

@@ -0,0 +1,48 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A StopSendingFrame is a STOP_SENDING frame
type StopSendingFrame struct {
StreamID protocol.StreamID
ErrorCode qerr.StreamErrorCode
}
// parseStopSendingFrame parses a STOP_SENDING frame
func parseStopSendingFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StopSendingFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
streamID, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
errorCode, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
return &StopSendingFrame{
StreamID: protocol.StreamID(streamID),
ErrorCode: qerr.StreamErrorCode(errorCode),
}, nil
}
// Length of a written frame
func (f *StopSendingFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.ErrorCode))
}
func (f *StopSendingFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
b.WriteByte(0x5)
quicvarint.Write(b, uint64(f.StreamID))
quicvarint.Write(b, uint64(f.ErrorCode))
return nil
}

View File

@@ -0,0 +1,46 @@
package wire
import (
"bytes"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A StreamDataBlockedFrame is a STREAM_DATA_BLOCKED frame
type StreamDataBlockedFrame struct {
StreamID protocol.StreamID
MaximumStreamData protocol.ByteCount
}
func parseStreamDataBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamDataBlockedFrame, error) {
if _, err := r.ReadByte(); err != nil {
return nil, err
}
sid, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
offset, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
return &StreamDataBlockedFrame{
StreamID: protocol.StreamID(sid),
MaximumStreamData: protocol.ByteCount(offset),
}, nil
}
func (f *StreamDataBlockedFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
b.WriteByte(0x15)
quicvarint.Write(b, uint64(f.StreamID))
quicvarint.Write(b, uint64(f.MaximumStreamData))
return nil
}
// Length of a written frame
func (f *StreamDataBlockedFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamID)) + quicvarint.Len(uint64(f.MaximumStreamData))
}

View File

@@ -0,0 +1,189 @@
package wire
import (
"bytes"
"errors"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A StreamFrame of QUIC
type StreamFrame struct {
StreamID protocol.StreamID
Offset protocol.ByteCount
Data []byte
Fin bool
DataLenPresent bool
fromPool bool
}
func parseStreamFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
hasOffset := typeByte&0x4 > 0
fin := typeByte&0x1 > 0
hasDataLen := typeByte&0x2 > 0
streamID, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
var offset uint64
if hasOffset {
offset, err = quicvarint.Read(r)
if err != nil {
return nil, err
}
}
var dataLen uint64
if hasDataLen {
var err error
dataLen, err = quicvarint.Read(r)
if err != nil {
return nil, err
}
} else {
// The rest of the packet is data
dataLen = uint64(r.Len())
}
var frame *StreamFrame
if dataLen < protocol.MinStreamFrameBufferSize {
frame = &StreamFrame{Data: make([]byte, dataLen)}
} else {
frame = GetStreamFrame()
// The STREAM frame can't be larger than the StreamFrame we obtained from the buffer,
// since those StreamFrames have a buffer length of the maximum packet size.
if dataLen > uint64(cap(frame.Data)) {
return nil, io.EOF
}
frame.Data = frame.Data[:dataLen]
}
frame.StreamID = protocol.StreamID(streamID)
frame.Offset = protocol.ByteCount(offset)
frame.Fin = fin
frame.DataLenPresent = hasDataLen
if dataLen != 0 {
if _, err := io.ReadFull(r, frame.Data); err != nil {
return nil, err
}
}
if frame.Offset+frame.DataLen() > protocol.MaxByteCount {
return nil, errors.New("stream data overflows maximum offset")
}
return frame, nil
}
// Write writes a STREAM frame
func (f *StreamFrame) Write(b *bytes.Buffer, version protocol.VersionNumber) error {
if len(f.Data) == 0 && !f.Fin {
return errors.New("StreamFrame: attempting to write empty frame without FIN")
}
typeByte := byte(0x8)
if f.Fin {
typeByte ^= 0x1
}
hasOffset := f.Offset != 0
if f.DataLenPresent {
typeByte ^= 0x2
}
if hasOffset {
typeByte ^= 0x4
}
b.WriteByte(typeByte)
quicvarint.Write(b, uint64(f.StreamID))
if hasOffset {
quicvarint.Write(b, uint64(f.Offset))
}
if f.DataLenPresent {
quicvarint.Write(b, uint64(f.DataLen()))
}
b.Write(f.Data)
return nil
}
// Length returns the total length of the STREAM frame
func (f *StreamFrame) Length(version protocol.VersionNumber) protocol.ByteCount {
length := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
length += quicvarint.Len(uint64(f.Offset))
}
if f.DataLenPresent {
length += quicvarint.Len(uint64(f.DataLen()))
}
return length + f.DataLen()
}
// DataLen gives the length of data in bytes
func (f *StreamFrame) DataLen() protocol.ByteCount {
return protocol.ByteCount(len(f.Data))
}
// MaxDataLen returns the maximum data length
// If 0 is returned, writing will fail (a STREAM frame must contain at least 1 byte of data).
func (f *StreamFrame) MaxDataLen(maxSize protocol.ByteCount, version protocol.VersionNumber) protocol.ByteCount {
headerLen := 1 + quicvarint.Len(uint64(f.StreamID))
if f.Offset != 0 {
headerLen += quicvarint.Len(uint64(f.Offset))
}
if f.DataLenPresent {
// pretend that the data size will be 1 bytes
// if it turns out that varint encoding the length will consume 2 bytes, we need to adjust the data length afterwards
headerLen++
}
if headerLen > maxSize {
return 0
}
maxDataLen := maxSize - headerLen
if f.DataLenPresent && quicvarint.Len(uint64(maxDataLen)) != 1 {
maxDataLen--
}
return maxDataLen
}
// MaybeSplitOffFrame splits a frame such that it is not bigger than n bytes.
// It returns if the frame was actually split.
// The frame might not be split if:
// * the size is large enough to fit the whole frame
// * the size is too small to fit even a 1-byte frame. In that case, the frame returned is nil.
func (f *StreamFrame) MaybeSplitOffFrame(maxSize protocol.ByteCount, version protocol.VersionNumber) (*StreamFrame, bool /* was splitting required */) {
if maxSize >= f.Length(version) {
return nil, false
}
n := f.MaxDataLen(maxSize, version)
if n == 0 {
return nil, true
}
new := GetStreamFrame()
new.StreamID = f.StreamID
new.Offset = f.Offset
new.Fin = false
new.DataLenPresent = f.DataLenPresent
// swap the data slices
new.Data, f.Data = f.Data, new.Data
new.fromPool, f.fromPool = f.fromPool, new.fromPool
f.Data = f.Data[:protocol.ByteCount(len(new.Data))-n]
copy(f.Data, new.Data[n:])
new.Data = new.Data[:n]
f.Offset += n
return new, true
}
func (f *StreamFrame) PutBack() {
putStreamFrame(f)
}

View File

@@ -0,0 +1,55 @@
package wire
import (
"bytes"
"fmt"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// A StreamsBlockedFrame is a STREAMS_BLOCKED frame
type StreamsBlockedFrame struct {
Type protocol.StreamType
StreamLimit protocol.StreamNum
}
func parseStreamsBlockedFrame(r *bytes.Reader, _ protocol.VersionNumber) (*StreamsBlockedFrame, error) {
typeByte, err := r.ReadByte()
if err != nil {
return nil, err
}
f := &StreamsBlockedFrame{}
switch typeByte {
case 0x16:
f.Type = protocol.StreamTypeBidi
case 0x17:
f.Type = protocol.StreamTypeUni
}
streamLimit, err := quicvarint.Read(r)
if err != nil {
return nil, err
}
f.StreamLimit = protocol.StreamNum(streamLimit)
if f.StreamLimit > protocol.MaxStreamCount {
return nil, fmt.Errorf("%d exceeds the maximum stream count", f.StreamLimit)
}
return f, nil
}
func (f *StreamsBlockedFrame) Write(b *bytes.Buffer, _ protocol.VersionNumber) error {
switch f.Type {
case protocol.StreamTypeBidi:
b.WriteByte(0x16)
case protocol.StreamTypeUni:
b.WriteByte(0x17)
}
quicvarint.Write(b, uint64(f.StreamLimit))
return nil
}
// Length of a written frame
func (f *StreamsBlockedFrame) Length(_ protocol.VersionNumber) protocol.ByteCount {
return 1 + quicvarint.Len(uint64(f.StreamLimit))
}

Some files were not shown because too many files have changed in this diff Show More