mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 18:39:58 +00:00
TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC
This commit is contained in:
@@ -1,43 +1,54 @@
|
||||
package datagramsession
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// TestCloseSession makes sure a session will stop after context is done
|
||||
func TestSessionCtxDone(t *testing.T) {
|
||||
testSessionReturns(t, true)
|
||||
testSessionReturns(t, closeByContext, time.Minute*2)
|
||||
}
|
||||
|
||||
// TestCloseSession makes sure a session will stop after close method is called
|
||||
func TestCloseSession(t *testing.T) {
|
||||
testSessionReturns(t, false)
|
||||
testSessionReturns(t, closeByCallingClose, time.Minute*2)
|
||||
}
|
||||
|
||||
func testSessionReturns(t *testing.T, closeByContext bool) {
|
||||
// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
|
||||
func TestCloseIdle(t *testing.T) {
|
||||
testSessionReturns(t, closeByTimeout, time.Millisecond*100)
|
||||
}
|
||||
|
||||
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(),
|
||||
respChan: newDatagramChannel(),
|
||||
reqChan: newDatagramChannel(1),
|
||||
respChan: newDatagramChannel(1),
|
||||
}
|
||||
session := newSession(sessionID, transport, cfdConn)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
sessionDone := make(chan struct{})
|
||||
go func() {
|
||||
session.Serve(ctx)
|
||||
session.Serve(ctx, closeAfterIdle)
|
||||
close(sessionDone)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
n, err := session.writeToDst(payload)
|
||||
n, err := session.transportToDst(payload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(payload), n)
|
||||
}()
|
||||
@@ -47,13 +58,120 @@ func testSessionReturns(t *testing.T, closeByContext bool) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, len(payload), n)
|
||||
|
||||
if closeByContext {
|
||||
lastRead := time.Now()
|
||||
|
||||
switch closeBy {
|
||||
case closeByContext:
|
||||
cancel()
|
||||
} else {
|
||||
case closeByCallingClose:
|
||||
session.close()
|
||||
}
|
||||
|
||||
<-sessionDone
|
||||
if closeBy == closeByTimeout {
|
||||
require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
|
||||
}
|
||||
// call cancelled again otherwise the linter will warn about possible context leak
|
||||
cancel()
|
||||
}
|
||||
|
||||
type closeMethod int
|
||||
|
||||
const (
|
||||
closeByContext closeMethod = iota
|
||||
closeByCallingClose
|
||||
closeByTimeout
|
||||
)
|
||||
|
||||
func TestWriteToDstSessionPreventClosed(t *testing.T) {
|
||||
testActiveSessionNotClosed(t, false, true)
|
||||
}
|
||||
|
||||
func TestReadFromDstSessionPreventClosed(t *testing.T) {
|
||||
testActiveSessionNotClosed(t, true, false)
|
||||
}
|
||||
|
||||
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
|
||||
const closeAfterIdle = time.Millisecond * 100
|
||||
const activeTime = time.Millisecond * 500
|
||||
|
||||
sessionID := uuid.New()
|
||||
cfdConn, originConn := net.Pipe()
|
||||
payload := testPayload(sessionID)
|
||||
transport := &mockQUICTransport{
|
||||
reqChan: newDatagramChannel(100),
|
||||
respChan: newDatagramChannel(100),
|
||||
}
|
||||
session := newSession(sessionID, transport, cfdConn)
|
||||
|
||||
startTime := time.Now()
|
||||
activeUntil := startTime.Add(activeTime)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
session.Serve(ctx, closeAfterIdle)
|
||||
if time.Now().Before(startTime.Add(activeTime)) {
|
||||
return fmt.Errorf("session closed while it's still active")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if readFromDst {
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
if time.Now().After(activeUntil) {
|
||||
return nil
|
||||
}
|
||||
if _, err := originConn.Write(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(closeAfterIdle / 2)
|
||||
}
|
||||
})
|
||||
}
|
||||
if writeToDst {
|
||||
errGroup.Go(func() error {
|
||||
readBuffer := make([]byte, len(payload))
|
||||
for {
|
||||
n, err := originConn.Read(readBuffer)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrClosedPipe {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !bytes.Equal(payload, readBuffer[:n]) {
|
||||
return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
|
||||
}
|
||||
}
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
if time.Now().After(activeUntil) {
|
||||
return nil
|
||||
}
|
||||
if _, err := session.transportToDst(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(closeAfterIdle / 2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
require.NoError(t, errGroup.Wait())
|
||||
cancel()
|
||||
}
|
||||
|
||||
func TestMarkActiveNotBlocking(t *testing.T) {
|
||||
const concurrentCalls = 50
|
||||
session := newSession(uuid.New(), nil, nil)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentCalls)
|
||||
for i := 0; i < concurrentCalls; i++ {
|
||||
go func() {
|
||||
session.markActive()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
Reference in New Issue
Block a user