TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC

cthuang
2021-12-02 11:02:27 +00:00
parent 9bc59bc78c
commit 73a265f2fc
13 changed files with 456 additions and 253 deletions


@@ -127,7 +127,7 @@ func (m *manager) sendToSession(datagram *newDatagram) {
}
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
// need to run in another goroutine
_, err := session.writeToDst(datagram.payload)
_, err := session.transportToDst(datagram.payload)
if err != nil {
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
}
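
The comment above relies on the fact that a connected UDP socket's Write does not wait on the peer. A minimal standalone sketch of that behaviour (address and payload are illustrative, not taken from the codebase):

package main

import (
	"log"
	"net"
)

func main() {
	raddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5353")
	if err != nil {
		log.Fatal(err)
	}
	// DialUDP returns a connected socket: Write sends to raddr immediately and
	// does not block waiting for the peer, so a caller like sendToSession can
	// write inline without spawning a goroutine.
	conn, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	if _, err := conn.Write([]byte("payload")); err != nil {
		log.Println("write failed:", err)
	}
}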


@@ -7,6 +7,7 @@ import (
"io"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
@@ -21,15 +22,15 @@ func TestManagerServe(t *testing.T) {
)
log := zerolog.Nop()
transport := &mockQUICTransport{
reqChan: newDatagramChannel(),
respChan: newDatagramChannel(),
reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1),
}
mg := NewManager(transport, &log)
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
for i := 0; i < sessions; i++ {
sessionID := uuid.New()
eyeballTracker[sessionID] = newDatagramChannel()
eyeballTracker[sessionID] = newDatagramChannel(1)
}
ctx, cancel := context.WithCancel(context.Background())
@@ -88,7 +89,7 @@ func TestManagerServe(t *testing.T) {
sessionDone := make(chan struct{})
go func() {
session.Serve(ctx)
session.Serve(ctx, time.Minute*2)
close(sessionDone)
}()
@@ -179,9 +180,9 @@ type datagramChannel struct {
closedChan chan struct{}
}
func newDatagramChannel() *datagramChannel {
func newDatagramChannel(capacity uint) *datagramChannel {
return &datagramChannel{
datagramChan: make(chan *newDatagram, 1),
datagramChan: make(chan *newDatagram, capacity),
closedChan: make(chan struct{}),
}
}
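
For context, the Send and Receive halves of this test helper are not shown in the hunk; they follow the usual context-aware channel pattern. A hedged sketch of that pattern with illustrative types (the real helper carries *newDatagram values and may differ in detail):

package main

import (
	"context"
	"fmt"
)

// datagramChannel (sketch): a buffered channel whose Send and Receive also
// honour a context and an explicit closed signal.
type datagramChannel struct {
	datagramChan chan []byte
	closedChan   chan struct{}
}

func newDatagramChannel(capacity uint) *datagramChannel {
	return &datagramChannel{
		datagramChan: make(chan []byte, capacity),
		closedChan:   make(chan struct{}),
	}
}

func (c *datagramChannel) Send(ctx context.Context, payload []byte) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.closedChan:
		return fmt.Errorf("channel closed")
	case c.datagramChan <- payload:
		return nil
	}
}

func (c *datagramChannel) Receive(ctx context.Context) ([]byte, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-c.closedChan:
		return nil, fmt.Errorf("channel closed")
	case payload := <-c.datagramChan:
		return payload, nil
	}
}

func main() {
	ch := newDatagramChannel(1)
	_ = ch.Send(context.Background(), []byte("hello"))
	payload, _ := ch.Receive(context.Background())
	fmt.Printf("%s\n", payload)
}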


@@ -3,10 +3,15 @@ package datagramsession
import (
"context"
"io"
"time"
"github.com/google/uuid"
)
const (
defaultCloseIdleAfter = time.Second * 210
)
// Each Session is a bidirectional pipe of datagrams between transport and dstConn
// Currently the only implementation of transport is quic DatagramMuxer
// Destination can be a connection with origin or with eyeball
@@ -22,7 +27,9 @@ type Session struct {
id uuid.UUID
transport transport
dstConn io.ReadWriteCloser
doneChan chan struct{}
// activeAtChan is used to communicate the last read/write time
activeAtChan chan time.Time
doneChan chan struct{}
}
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
@@ -30,41 +37,81 @@ func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *
id: id,
transport: transport,
dstConn: dstConn,
doneChan: make(chan struct{}),
// activeAtChan has low capacity. It can be full when there are many concurrent reads/writes. markActive() will
// drop instead of blocking because the last active time only needs to be an approximation
activeAtChan: make(chan time.Time, 2),
doneChan: make(chan struct{}),
}
}
func (s *Session) Serve(ctx context.Context) error {
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) error {
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
select {
case <-serveCtx.Done():
case <-s.doneChan:
}
s.dstConn.Close()
}()
go s.waitForCloseCondition(serveCtx, closeAfterIdle)
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
// This makes it safe to share readBuffer between iterations
readBuffer := make([]byte, 1280)
readBuffer := make([]byte, s.transport.MTU())
for {
// TODO: TUN-5303: origin proxy should determine the buffer size
n, err := s.dstConn.Read(readBuffer)
if n > 0 {
if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
return err
}
}
if err != nil {
if err := s.dstToTransport(readBuffer); err != nil {
return err
}
}
}
func (s *Session) writeToDst(payload []byte) (int, error) {
func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) {
if closeAfterIdle == 0 {
// provide a default if the caller doesn't specify one
closeAfterIdle = defaultCloseIdleAfter
}
// Closing dstConn cancels the read so the Serve function can return
defer s.dstConn.Close()
checkIdleFreq := closeAfterIdle / 8
checkIdleTicker := time.NewTicker(checkIdleFreq)
defer checkIdleTicker.Stop()
activeAt := time.Now()
for {
select {
case <-ctx.Done():
return
case <-s.doneChan:
return
case <-checkIdleTicker.C:
// The session is considered inactive if the current time is after (last active time + allowed idle time)
if time.Now().After(activeAt.Add(closeAfterIdle)) {
return
}
case activeAt = <-s.activeAtChan: // Update last active time
}
}
}
func (s *Session) dstToTransport(buffer []byte) error {
n, err := s.dstConn.Read(buffer)
s.markActive()
if n > 0 {
if err := s.transport.SendTo(s.id, buffer[:n]); err != nil {
return err
}
}
return err
}
func (s *Session) transportToDst(payload []byte) (int, error) {
s.markActive()
return s.dstConn.Write(payload)
}
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
// are many concurrent reads/writes. It is fine to lose some precision
func (s *Session) markActive() {
select {
case s.activeAtChan <- time.Now():
default:
}
}
func (s *Session) close() {
close(s.doneChan)
}
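
Taken out of the session plumbing, the waitForCloseCondition/markActive pair is a small idle-watchdog pattern: activity is reported through a tiny buffered channel that drops updates when full, and a ticker firing at closeAfterIdle/8 decides when to give up. A standalone sketch of that pattern (the idleWatcher type and timings are illustrative, not part of the package):

package main

import (
	"fmt"
	"time"
)

// idleWatcher (sketch): callers report activity via markActive; waitForIdle
// returns once nothing has been reported for closeAfterIdle.
type idleWatcher struct {
	activeAt chan time.Time
	done     chan struct{}
}

func newIdleWatcher() *idleWatcher {
	return &idleWatcher{
		// Small buffer: concurrent callers may fill it, and that is fine.
		activeAt: make(chan time.Time, 2),
		// done mirrors the session's doneChan; closing it stops the watcher early.
		done: make(chan struct{}),
	}
}

// markActive never blocks: losing an update only makes the idle estimate a
// little conservative, which is acceptable for this purpose.
func (w *idleWatcher) markActive() {
	select {
	case w.activeAt <- time.Now():
	default:
	}
}

func (w *idleWatcher) waitForIdle(closeAfterIdle time.Duration) {
	ticker := time.NewTicker(closeAfterIdle / 8)
	defer ticker.Stop()
	activeAt := time.Now()
	for {
		select {
		case <-w.done:
			return
		case <-ticker.C:
			if time.Now().After(activeAt.Add(closeAfterIdle)) {
				fmt.Println("idle, stopping")
				return
			}
		case activeAt = <-w.activeAt:
		}
	}
}

func main() {
	w := newIdleWatcher()
	go func() {
		// Simulate a short burst of traffic, then fall silent.
		for i := 0; i < 3; i++ {
			time.Sleep(50 * time.Millisecond)
			w.markActive()
		}
	}()
	w.waitForIdle(200 * time.Millisecond)
}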


@@ -1,43 +1,54 @@
package datagramsession
import (
"bytes"
"context"
"fmt"
"io"
"net"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
)
// TestSessionCtxDone makes sure a session will stop after the context is done
func TestSessionCtxDone(t *testing.T) {
testSessionReturns(t, true)
testSessionReturns(t, closeByContext, time.Minute*2)
}
// TestCloseSession makes sure a session will stop after the close method is called
func TestCloseSession(t *testing.T) {
testSessionReturns(t, false)
testSessionReturns(t, closeByCallingClose, time.Minute*2)
}
func testSessionReturns(t *testing.T, closeByContext bool) {
// TestCloseIdle makes sure a session will stop after there has been no read/write for a period defined by closeAfterIdle
func TestCloseIdle(t *testing.T) {
testSessionReturns(t, closeByTimeout, time.Millisecond*100)
}
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
transport := &mockQUICTransport{
reqChan: newDatagramChannel(),
respChan: newDatagramChannel(),
reqChan: newDatagramChannel(1),
respChan: newDatagramChannel(1),
}
session := newSession(sessionID, transport, cfdConn)
ctx, cancel := context.WithCancel(context.Background())
sessionDone := make(chan struct{})
go func() {
session.Serve(ctx)
session.Serve(ctx, closeAfterIdle)
close(sessionDone)
}()
go func() {
n, err := session.writeToDst(payload)
n, err := session.transportToDst(payload)
require.NoError(t, err)
require.Equal(t, len(payload), n)
}()
@@ -47,13 +58,120 @@ func testSessionReturns(t *testing.T, closeByContext bool) {
require.NoError(t, err)
require.Equal(t, len(payload), n)
if closeByContext {
lastRead := time.Now()
switch closeBy {
case closeByContext:
cancel()
} else {
case closeByCallingClose:
session.close()
}
<-sessionDone
if closeBy == closeByTimeout {
require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
}
// call cancel again, otherwise the linter will warn about a possible context leak
cancel()
}
type closeMethod int
const (
closeByContext closeMethod = iota
closeByCallingClose
closeByTimeout
)
func TestWriteToDstSessionPreventClosed(t *testing.T) {
testActiveSessionNotClosed(t, false, true)
}
func TestReadFromDstSessionPreventClosed(t *testing.T) {
testActiveSessionNotClosed(t, true, false)
}
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
const closeAfterIdle = time.Millisecond * 100
const activeTime = time.Millisecond * 500
sessionID := uuid.New()
cfdConn, originConn := net.Pipe()
payload := testPayload(sessionID)
transport := &mockQUICTransport{
reqChan: newDatagramChannel(100),
respChan: newDatagramChannel(100),
}
session := newSession(sessionID, transport, cfdConn)
startTime := time.Now()
activeUntil := startTime.Add(activeTime)
ctx, cancel := context.WithCancel(context.Background())
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
session.Serve(ctx, closeAfterIdle)
if time.Now().Before(startTime.Add(activeTime)) {
return fmt.Errorf("session closed while it's still active")
}
return nil
})
if readFromDst {
errGroup.Go(func() error {
for {
if time.Now().After(activeUntil) {
return nil
}
if _, err := originConn.Write(payload); err != nil {
return err
}
time.Sleep(closeAfterIdle / 2)
}
})
}
if writeToDst {
errGroup.Go(func() error {
readBuffer := make([]byte, len(payload))
for {
n, err := originConn.Read(readBuffer)
if err != nil {
if err == io.EOF || err == io.ErrClosedPipe {
return nil
}
return err
}
if !bytes.Equal(payload, readBuffer[:n]) {
return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
}
}
})
errGroup.Go(func() error {
for {
if time.Now().After(activeUntil) {
return nil
}
if _, err := session.transportToDst(payload); err != nil {
return err
}
time.Sleep(closeAfterIdle / 2)
}
})
}
require.NoError(t, errGroup.Wait())
cancel()
}
func TestMarkActiveNotBlocking(t *testing.T) {
const concurrentCalls = 50
session := newSession(uuid.New(), nil, nil)
var wg sync.WaitGroup
wg.Add(concurrentCalls)
for i := 0; i < concurrentCalls; i++ {
go func() {
session.markActive()
wg.Done()
}()
}
wg.Wait()
}


@@ -8,4 +8,6 @@ type transport interface {
SendTo(sessionID uuid.UUID, payload []byte) error
// ReceiveFrom reads the next datagram from the transport
ReceiveFrom() (uuid.UUID, []byte, error)
// Max transmission unit of the transport
MTU() uint
}
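
Every transport now has to report its MTU so Serve can size its read buffer. A hedged sketch of a toy implementation satisfying the interface (pipeTransport and its channels are illustrative; the real implementation is the quic DatagramMuxer):

package main

import (
	"fmt"

	"github.com/google/uuid"
)

type pipeDatagram struct {
	sessionID uuid.UUID
	payload   []byte
}

// pipeTransport (sketch): an in-memory transport that satisfies the interface
// above, including the new MTU method.
type pipeTransport struct {
	in  chan pipeDatagram
	out chan pipeDatagram
	mtu uint
}

func (t *pipeTransport) SendTo(sessionID uuid.UUID, payload []byte) error {
	t.out <- pipeDatagram{sessionID: sessionID, payload: payload}
	return nil
}

func (t *pipeTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
	d := <-t.in
	return d.sessionID, d.payload, nil
}

func (t *pipeTransport) MTU() uint {
	return t.mtu
}

func main() {
	transport := &pipeTransport{
		in:  make(chan pipeDatagram, 1),
		out: make(chan pipeDatagram, 1),
		mtu: 1220,
	}
	transport.in <- pipeDatagram{sessionID: uuid.New(), payload: []byte("hello")}
	id, payload, _ := transport.ReceiveFrom()
	// Serve sizes its read buffer from MTU(), i.e. make([]byte, transport.MTU())
	fmt.Println(id, string(payload), transport.MTU())
}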


@@ -22,6 +22,10 @@ func (mt *mockQUICTransport) ReceiveFrom() (uuid.UUID, []byte, error) {
return mt.reqChan.Receive(context.Background())
}
func (mt *mockQUICTransport) MTU() uint {
return 1220
}
func (mt *mockQUICTransport) newRequest(ctx context.Context, sessionID uuid.UUID, payload []byte) error {
return mt.reqChan.Send(ctx, sessionID, payload)
}