mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:29:57 +00:00
TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC
This commit is contained in:
@@ -127,7 +127,7 @@ func (m *manager) sendToSession(datagram *newDatagram) {
|
||||
}
|
||||
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
|
||||
// need to run in another go routine
|
||||
_, err := session.writeToDst(datagram.payload)
|
||||
_, err := session.transportToDst(datagram.payload)
|
||||
if err != nil {
|
||||
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -21,15 +22,15 @@ func TestManagerServe(t *testing.T) {
|
||||
)
|
||||
log := zerolog.Nop()
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(),
|
||||
respChan: newDatagramChannel(),
|
||||
reqChan: newDatagramChannel(1),
|
||||
respChan: newDatagramChannel(1),
|
||||
}
|
||||
mg := NewManager(transport, &log)
|
||||
|
||||
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
|
||||
for i := 0; i < sessions; i++ {
|
||||
sessionID := uuid.New()
|
||||
eyeballTracker[sessionID] = newDatagramChannel()
|
||||
eyeballTracker[sessionID] = newDatagramChannel(1)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -88,7 +89,7 @@ func TestManagerServe(t *testing.T) {
|
||||
|
||||
sessionDone := make(chan struct{})
|
||||
go func() {
|
||||
session.Serve(ctx)
|
||||
session.Serve(ctx, time.Minute*2)
|
||||
close(sessionDone)
|
||||
}()
|
||||
|
||||
@@ -179,9 +180,9 @@ type datagramChannel struct {
|
||||
closedChan chan struct{}
|
||||
}
|
||||
|
||||
func newDatagramChannel() *datagramChannel {
|
||||
func newDatagramChannel(capacity uint) *datagramChannel {
|
||||
return &datagramChannel{
|
||||
datagramChan: make(chan *newDatagram, 1),
|
||||
datagramChan: make(chan *newDatagram, capacity),
|
||||
closedChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
@@ -3,10 +3,15 @@ package datagramsession
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloseIdleAfter = time.Second * 210
|
||||
)
|
||||
|
||||
// Each Session is a bidirectional pipe of datagrams between transport and dstConn
|
||||
// Currently the only implementation of transport is quic DatagramMuxer
|
||||
// Destination can be a connection with origin or with eyeball
|
||||
@@ -22,7 +27,9 @@ type Session struct {
|
||||
id uuid.UUID
|
||||
transport transport
|
||||
dstConn io.ReadWriteCloser
|
||||
doneChan chan struct{}
|
||||
// activeAtChan is used to communicate the last read/write time
|
||||
activeAtChan chan time.Time
|
||||
doneChan chan struct{}
|
||||
}
|
||||
|
||||
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
|
||||
@@ -30,41 +37,81 @@ func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *
|
||||
id: id,
|
||||
transport: transport,
|
||||
dstConn: dstConn,
|
||||
doneChan: make(chan struct{}),
|
||||
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
|
||||
// drop instead of blocking because last active time only needs to be an approximation
|
||||
activeAtChan: make(chan time.Time, 2),
|
||||
doneChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) Serve(ctx context.Context) error {
|
||||
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) error {
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
go func() {
|
||||
select {
|
||||
case <-serveCtx.Done():
|
||||
case <-s.doneChan:
|
||||
}
|
||||
s.dstConn.Close()
|
||||
}()
|
||||
go s.waitForCloseCondition(serveCtx, closeAfterIdle)
|
||||
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
|
||||
// This makes it safe to share readBuffer between iterations
|
||||
readBuffer := make([]byte, 1280)
|
||||
readBuffer := make([]byte, s.transport.MTU())
|
||||
for {
|
||||
// TODO: TUN-5303: origin proxy should determine the buffer size
|
||||
n, err := s.dstConn.Read(readBuffer)
|
||||
if n > 0 {
|
||||
if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err := s.dstToTransport(readBuffer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) writeToDst(payload []byte) (int, error) {
|
||||
func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) {
|
||||
if closeAfterIdle == 0 {
|
||||
// provide deafult is caller doesn't specify one
|
||||
closeAfterIdle = defaultCloseIdleAfter
|
||||
}
|
||||
// Closing dstConn cancels read so Serve function can return
|
||||
defer s.dstConn.Close()
|
||||
|
||||
checkIdleFreq := closeAfterIdle / 8
|
||||
checkIdleTicker := time.NewTicker(checkIdleFreq)
|
||||
defer checkIdleTicker.Stop()
|
||||
|
||||
activeAt := time.Now()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.doneChan:
|
||||
return
|
||||
case <-checkIdleTicker.C:
|
||||
// The session is considered inactive if current time is after (last active time + allowed idle time)
|
||||
if time.Now().After(activeAt.Add(closeAfterIdle)) {
|
||||
return
|
||||
}
|
||||
case activeAt = <-s.activeAtChan: // Update last active time
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) dstToTransport(buffer []byte) error {
|
||||
n, err := s.dstConn.Read(buffer)
|
||||
s.markActive()
|
||||
if n > 0 {
|
||||
if err := s.transport.SendTo(s.id, buffer[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Session) transportToDst(payload []byte) (int, error) {
|
||||
s.markActive()
|
||||
return s.dstConn.Write(payload)
|
||||
}
|
||||
|
||||
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
|
||||
// are many concurrent read/write. It is fine to lose some precision
|
||||
func (s *Session) markActive() {
|
||||
select {
|
||||
case s.activeAtChan <- time.Now():
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) close() {
|
||||
close(s.doneChan)
|
||||
}
|
||||
|
@@ -1,43 +1,54 @@
|
||||
package datagramsession
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// TestCloseSession makes sure a session will stop after context is done
|
||||
func TestSessionCtxDone(t *testing.T) {
|
||||
testSessionReturns(t, true)
|
||||
testSessionReturns(t, closeByContext, time.Minute*2)
|
||||
}
|
||||
|
||||
// TestCloseSession makes sure a session will stop after close method is called
|
||||
func TestCloseSession(t *testing.T) {
|
||||
testSessionReturns(t, false)
|
||||
testSessionReturns(t, closeByCallingClose, time.Minute*2)
|
||||
}
|
||||
|
||||
func testSessionReturns(t *testing.T, closeByContext bool) {
|
||||
// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
|
||||
func TestCloseIdle(t *testing.T) {
|
||||
testSessionReturns(t, closeByTimeout, time.Millisecond*100)
|
||||
}
|
||||
|
||||
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(),
|
||||
respChan: newDatagramChannel(),
|
||||
reqChan: newDatagramChannel(1),
|
||||
respChan: newDatagramChannel(1),
|
||||
}
|
||||
session := newSession(sessionID, transport, cfdConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionDone := make(chan struct{})
|
||||
go func() {
|
||||
session.Serve(ctx)
|
||||
session.Serve(ctx, closeAfterIdle)
|
||||
close(sessionDone)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
n, err := session.writeToDst(payload)
|
||||
n, err := session.transportToDst(payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(payload), n)
|
||||
}()
|
||||
@@ -47,13 +58,120 @@ func testSessionReturns(t *testing.T, closeByContext bool) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(payload), n)
|
||||
|
||||
if closeByContext {
|
||||
lastRead := time.Now()
|
||||
|
||||
switch closeBy {
|
||||
case closeByContext:
|
||||
cancel()
|
||||
} else {
|
||||
case closeByCallingClose:
|
||||
session.close()
|
||||
}
|
||||
|
||||
<-sessionDone
|
||||
if closeBy == closeByTimeout {
|
||||
require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
|
||||
}
|
||||
// call cancelled again otherwise the linter will warn about possible context leak
|
||||
cancel()
|
||||
}
|
||||
|
||||
type closeMethod int
|
||||
|
||||
const (
|
||||
closeByContext closeMethod = iota
|
||||
closeByCallingClose
|
||||
closeByTimeout
|
||||
)
|
||||
|
||||
func TestWriteToDstSessionPreventClosed(t *testing.T) {
|
||||
testActiveSessionNotClosed(t, false, true)
|
||||
}
|
||||
|
||||
func TestReadFromDstSessionPreventClosed(t *testing.T) {
|
||||
testActiveSessionNotClosed(t, true, false)
|
||||
}
|
||||
|
||||
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
|
||||
const closeAfterIdle = time.Millisecond * 100
|
||||
const activeTime = time.Millisecond * 500
|
||||
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(100),
|
||||
respChan: newDatagramChannel(100),
|
||||
}
|
||||
session := newSession(sessionID, transport, cfdConn)
|
||||
|
||||
startTime := time.Now()
|
||||
activeUntil := startTime.Add(activeTime)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
session.Serve(ctx, closeAfterIdle)
|
||||
if time.Now().Before(startTime.Add(activeTime)) {
|
||||
return fmt.Errorf("session closed while it's still active")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if readFromDst {
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
if time.Now().After(activeUntil) {
|
||||
return nil
|
||||
}
|
||||
if _, err := originConn.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(closeAfterIdle / 2)
|
||||
}
|
||||
})
|
||||
}
|
||||
if writeToDst {
|
||||
errGroup.Go(func() error {
|
||||
readBuffer := make([]byte, len(payload))
|
||||
for {
|
||||
n, err := originConn.Read(readBuffer)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrClosedPipe {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(payload, readBuffer[:n]) {
|
||||
return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
|
||||
}
|
||||
}
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
if time.Now().After(activeUntil) {
|
||||
return nil
|
||||
}
|
||||
if _, err := session.transportToDst(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(closeAfterIdle / 2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, errGroup.Wait())
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestMarkActiveNotBlocking(t *testing.T) {
|
||||
const concurrentCalls = 50
|
||||
session := newSession(uuid.New(), nil, nil)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentCalls)
|
||||
for i := 0; i < concurrentCalls; i++ {
|
||||
go func() {
|
||||
session.markActive()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
@@ -8,4 +8,6 @@ type transport interface {
|
||||
SendTo(sessionID uuid.UUID, payload []byte) error
|
||||
// ReceiveFrom reads the next datagram from the transport
|
||||
ReceiveFrom() (uuid.UUID, []byte, error)
|
||||
// Max transmission unit of the transport
|
||||
MTU() uint
|
||||
}
|
||||
|
@@ -22,6 +22,10 @@ func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||
return mt.reqChan.Receive(context.Background())
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) MTU() uint {
|
||||
return 1220
|
||||
}
|
||||
|
||||
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
|
||||
return mt.reqChan.Send(ctx, sessionID, payload)
|
||||
}
|
||||
|
Reference in New Issue
Block a user