cloudflared/quic/v3/session_test.go
Devin Carr 6a6c890700 TUN-8667: Add datagram v3 session manager
The new session manager provides functionality similar to what was previously
provided with datagram v2, with one distinct difference: sessions are
registered via QUIC Datagrams and unregistered via timeouts only. Sessions
will no longer be unregistered remotely with the edge service.

The Session Manager is shared across all QUIC connections that cloudflared
uses to connect to the edge (typically 4). This lets cloudflared monitor
all sessions across the connections and, in the future, correlate sessions
that migrate across connections.

The UDP payload size is still limited to 1280 bytes on all operating systems.
Any UDP packet with a payload larger than 1280 bytes will cause cloudflared
to log an error (as it currently does) and drop the packet.

Closes TUN-8667
2024-10-31 14:05:15 -07:00

284 lines
6.8 KiB
Go

package v3_test
import (
"context"
"errors"
"net"
"slices"
"sync/atomic"
"testing"
"time"
"github.com/rs/zerolog"
v3 "github.com/cloudflare/cloudflared/quic/v3"
)
// expectedContextCanceled is a sentinel error that the tests pass as the
// cancellation cause, so a deliberate cancellation can be told apart from any
// other one by inspecting context.Cause.
var expectedContextCanceled = errors.New("expected context canceled")
// TestSessionNew verifies that a freshly constructed session reports the
// request ID it was created with.
func TestSessionNew(t *testing.T) {
	logger := zerolog.Nop()
	s := v3.NewSession(testRequestID, 5*time.Second, nil, &noopEyeball{}, &logger)
	if s.ID() == testRequestID {
		return
	}
	t.Fatalf("session id doesn't match: %s != %s", testRequestID, s.ID())
}
// testSessionWrite writes payload through a new session and verifies that the
// whole payload reached the origin unchanged.
//
// It fails the calling test if Write returns an error, reports fewer bytes
// than len(payload), or the origin recorded different bytes.
func testSessionWrite(t *testing.T, payload []byte) {
	log := zerolog.Nop()
	// This helper only inspects the bytes recorded by the origin's Write.
	origin := newTestOrigin(makePayload(1280))
	session := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &log)

	n, err := session.Write(payload)
	if err != nil {
		t.Fatal(err)
	}
	if n != len(payload) {
		t.Fatal("unable to write the whole payload")
	}
	// Guard the slice expression below: if the origin recorded fewer bytes
	// than were written, fail cleanly instead of panicking out of range.
	if len(origin.write) < len(payload) {
		t.Fatalf("origin recorded %d bytes, expected at least %d", len(origin.write), len(payload))
	}
	if !slices.Equal(payload, origin.write[:len(payload)]) {
		t.Fatal("payload provided from origin and read value are not the same")
	}
}
func TestSessionWrite_Max(t *testing.T) {
payload := makePayload(1280)
testSessionWrite(t, payload)
}
func TestSessionWrite_Min(t *testing.T) {
payload := makePayload(0)
testSessionWrite(t, payload)
}
func TestSessionServe_OriginMax(t *testing.T) {
payload := makePayload(1280)
testSessionServe_Origin(t, payload)
}
func TestSessionServe_OriginMin(t *testing.T) {
payload := makePayload(0)
testSessionServe_Origin(t, payload)
}
// testSessionServe_Origin serves a session whose origin yields payload and
// verifies that the eyeball receives it as a datagram: a header marshaled for
// testRequestID followed by the payload bytes.
//
// Once the datagram has arrived the context is canceled with
// expectedContextCanceled; Serve is then expected to return an error matching
// context.Canceled, and the context cause must be expectedContextCanceled.
func testSessionServe_Origin(t *testing.T, payload []byte) {
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	origin := newTestOrigin(payload)
	session := v3.NewSession(testRequestID, 3*time.Second, &origin, &eyeball, &log)
	defer session.Close()

	ctx, cancel := context.WithCancelCause(context.Background())
	defer cancel(context.Canceled)

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even when the test bails out early through t.Fatal.
	done := make(chan error, 1)
	go func() {
		done <- session.Serve(ctx)
	}()

	select {
	case data := <-eyeball.recvData:
		// check received data matches provided from origin
		expectedData := makePayload(1500)
		v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
		// The payload is expected to start at offset 17, right after the header.
		copy(expectedData[17:], payload)
		if !slices.Equal(expectedData[:17+len(payload)], data) {
			t.Fatal("expected datagram did not equal expected")
		}
		cancel(expectedContextCanceled)
	case err := <-done:
		// Serve ended before any datagram reached the eyeball; report it
		// instead of waiting forever for data that will never arrive.
		t.Fatalf("session stopped serving before proxying the payload: %v", err)
	case <-ctx.Done():
		// we expect the payload to return before the context to cancel on the session
		t.Fatal(context.Cause(ctx))
	}

	err := <-done
	if !errors.Is(err, context.Canceled) {
		t.Fatal(err)
	}
	if cause := context.Cause(ctx); !errors.Is(cause, expectedContextCanceled) {
		t.Fatal(cause)
	}
}
// TestSessionServe_OriginTooLarge serves a session whose origin yields a
// 1281-byte payload and expects that nothing is delivered to the eyeball;
// Serve should instead end with an error matching v3.SessionIdleErr.
func TestSessionServe_OriginTooLarge(t *testing.T) {
	log := zerolog.Nop()
	eyeball := newMockEyeball()
	payload := makePayload(1281)
	origin := newTestOrigin(payload)
	session := v3.NewSession(testRequestID, 2*time.Second, &origin, &eyeball, &log)
	defer session.Close()

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even when the test fails below without ever reading from done.
	done := make(chan error, 1)
	go func() {
		done <- session.Serve(context.Background())
	}()

	select {
	case data := <-eyeball.recvData:
		// we never expect a read to make it here because the origin provided a payload that is too large
		// for cloudflared to proxy and it will drop it.
		t.Fatalf("we should never proxy a payload of this size: %d", len(data))
	case err := <-done:
		if !errors.Is(err, v3.SessionIdleErr{}) {
			t.Error(err)
		}
	}
}
// TestSessionClose_Multiple closes a session twice: the first Close must mark
// the origin as closed, and neither call may return an error.
func TestSessionClose_Multiple(t *testing.T) {
	logger := zerolog.Nop()
	origin := newTestOrigin(makePayload(128))
	s := v3.NewSession(testRequestID, 5*time.Second, &origin, &noopEyeball{}, &logger)

	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
	if closed := origin.closed.Load(); !closed {
		t.Fatal("origin wasn't closed")
	}

	// A repeated Close is expected to return no error.
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
}
// TestSessionServe_IdleTimeout serves a session whose origin stalls for longer
// than the session's idle timeout and expects Serve to return an error matching
// v3.SessionIdleErr, leaving the origin closed.
func TestSessionServe_IdleTimeout(t *testing.T) {
	const (
		// The origin stalls for longer than idleTimeout.
		originReadDelay = 10 * time.Second
		idleTimeout     = 2 * time.Second
	)
	logger := zerolog.Nop()
	origin := newTestIdleOrigin(originReadDelay)
	s := v3.NewSession(testRequestID, idleTimeout, &origin, &noopEyeball{}, &logger)

	if err := s.Serve(context.Background()); !errors.Is(err, v3.SessionIdleErr{}) {
		t.Fatal(err)
	}

	// Serve returning implies the origin has been closed.
	if !origin.closed {
		t.Fatalf("session should be closed after Serve returns")
	}

	// Closing the already-closed session must not produce an error.
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
}
// TestSessionServe_ParentContextCanceled serves a session under a context that
// expires before either the origin's stall or the idle timeout elapses, and
// expects Serve to return an error matching context.DeadlineExceeded, leaving
// the origin closed.
func TestSessionServe_ParentContextCanceled(t *testing.T) {
	const (
		// Both the origin stall and the idle timeout outlast parentDeadline.
		originReadDelay = 10 * time.Second
		idleTimeout     = 10 * time.Second
		parentDeadline  = 2 * time.Second
	)
	logger := zerolog.Nop()
	origin := newTestIdleOrigin(originReadDelay)
	s := v3.NewSession(testRequestID, idleTimeout, &origin, &noopEyeball{}, &logger)

	parent, stop := context.WithTimeout(context.Background(), parentDeadline)
	defer stop()

	if err := s.Serve(parent); !errors.Is(err, context.DeadlineExceeded) {
		t.Fatal(err)
	}

	// Serve returning implies the origin has been closed.
	if !origin.closed {
		t.Fatalf("session should be closed after Serve returns")
	}

	// Closing the already-closed session must not produce an error.
	if err := s.Close(); err != nil {
		t.Fatal(err)
	}
}
// TestSessionServe_ReadErrors serves a session whose origin fails every Read
// with net.ErrClosed and expects Serve to return an error matching it.
func TestSessionServe_ReadErrors(t *testing.T) {
	logger := zerolog.Nop()
	failing := newTestErrOrigin(net.ErrClosed, nil)
	s := v3.NewSession(testRequestID, 30*time.Second, &failing, &noopEyeball{}, &logger)

	if err := s.Serve(context.Background()); !errors.Is(err, net.ErrClosed) {
		t.Fatal(err)
	}
}
// testOrigin is an in-memory origin stub. Read hands out a fixed payload,
// Write records the most recently written bytes, and Close marks the origin
// as closed so that later Read and Write calls fail with net.ErrClosed.
type testOrigin struct {
	// bytes from Write
	write []byte
	// bytes provided to Read
	read     []byte
	readOnce atomic.Bool
	closed   atomic.Bool
}

// newTestOrigin returns a testOrigin whose Read serves payload.
func newTestOrigin(payload []byte) testOrigin {
	return testOrigin{
		read: payload,
	}
}

// Read copies the configured payload into p and returns the number of bytes
// copied. The first call returns immediately; every later call first stalls
// for 10 seconds. A closed origin returns 0 and net.ErrClosed.
func (o *testOrigin) Read(p []byte) (n int, err error) {
	if o.closed.Load() {
		// io.Reader requires 0 <= n <= len(p), even when an error is returned.
		return 0, net.ErrClosed
	}
	// Swap marks the payload as served and reports, in one atomic step,
	// whether an earlier call had already done so.
	if o.readOnce.Swap(true) {
		// We only want to provide one read so all other reads will be blocked
		time.Sleep(10 * time.Second)
		// The origin may have been closed while this call was stalled.
		if o.closed.Load() {
			return 0, net.ErrClosed
		}
	}
	return copy(p, o.read), nil
}

// Write stores a private copy of p, replacing any previously recorded bytes,
// and returns len(p). A closed origin returns 0 and net.ErrClosed.
func (o *testOrigin) Write(p []byte) (n int, err error) {
	if o.closed.Load() {
		// io.Writer requires 0 <= n <= len(p), even when an error is returned.
		return 0, net.ErrClosed
	}
	o.write = make([]byte, len(p))
	copy(o.write, p)
	return len(p), nil
}

// Close marks the origin as closed. It always returns nil.
func (o *testOrigin) Close() error {
	o.closed.Store(true)
	return nil
}
// testIdleOrigin is an origin stub that never produces data: Read stalls for
// a fixed duration and then reports zero bytes, Write reports zero bytes, and
// Close records that it was called.
type testIdleOrigin struct {
	// how long each Read stalls before returning
	duration time.Duration
	// set once Close has been called
	closed bool
}

// newTestIdleOrigin returns a testIdleOrigin whose Read stalls for d.
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
	var idle testIdleOrigin
	idle.duration = d
	return idle
}

// Read sleeps for the configured duration and then reports no bytes and no
// error.
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
	stall := o.duration
	time.Sleep(stall)
	return 0, nil
}

// Write reports zero bytes written and no error.
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
	return 0, nil
}

// Close records that the origin was closed. It always returns nil.
func (o *testIdleOrigin) Close() error {
	o.closed = true
	return nil
}
// testErrOrigin is an origin stub whose Read and Write return preconfigured
// errors.
type testErrOrigin struct {
	// returned by every Read
	readErr error
	// returned by every Write
	writeErr error
}

// newTestErrOrigin returns a testErrOrigin that reports readErr from Read and
// writeErr from Write.
func newTestErrOrigin(readErr error, writeErr error) testErrOrigin {
	return testErrOrigin{
		readErr:  readErr,
		writeErr: writeErr,
	}
}

// Read reports zero bytes together with the configured read error.
func (o *testErrOrigin) Read(p []byte) (n int, err error) {
	err = o.readErr
	return 0, err
}

// Write reports len(p) bytes written together with the configured write error.
func (o *testErrOrigin) Write(p []byte) (n int, err error) {
	n, err = len(p), o.writeErr
	return n, err
}

// Close does nothing and always returns nil.
func (o *testErrOrigin) Close() error {
	return nil
}