mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:50:00 +00:00
TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC
This commit is contained in:
@@ -57,6 +57,10 @@ func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
|
||||
return ExtractSessionID(msg)
|
||||
}
|
||||
|
||||
func (dm *DatagramMuxer) MTU() uint {
|
||||
return MaxDatagramFrameSize
|
||||
}
|
||||
|
||||
// Each QUIC datagram should be suffixed with session ID.
|
||||
// ExtractSessionID extracts the session ID and a slice with only the payload
|
||||
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
@@ -239,8 +240,8 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort)
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration) error {
|
||||
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -15,6 +16,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
testCloseIdleAfterHint = time.Minute * 2
|
||||
)
|
||||
|
||||
func TestConnectRequestData(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
@@ -110,9 +115,10 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
serverStream := mockRPCStream{serverReader, serverWriter}
|
||||
|
||||
rpcServer := mockRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
}
|
||||
logger := zerolog.Nop()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
@@ -131,10 +137,10 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort))
|
||||
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort))
|
||||
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
|
||||
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID))
|
||||
|
||||
@@ -146,12 +152,13 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
}
|
||||
|
||||
type mockRPCServer struct {
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
closeIdleAfter time.Duration
|
||||
}
|
||||
|
||||
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
@@ -159,7 +166,10 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
|
||||
return fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
|
||||
}
|
||||
if s.dstPort != dstPort {
|
||||
return fmt.Errorf("expect session ID %d, got %d", s.dstPort, dstPort)
|
||||
return fmt.Errorf("expect destination port %d, got %d", s.dstPort, dstPort)
|
||||
}
|
||||
if s.closeIdleAfter != closeIdleAfter {
|
||||
return fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user