TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC

This commit is contained in:
cthuang
2021-12-02 11:02:27 +00:00
parent 9bc59bc78c
commit 73a265f2fc
13 changed files with 456 additions and 253 deletions

View File

@@ -57,6 +57,10 @@ func (dm *DatagramMuxer) ReceiveFrom() (uuid.UUID, []byte, error) {
return ExtractSessionID(msg)
}
func (dm *DatagramMuxer) MTU() uint {
return MaxDatagramFrameSize
}
// Each QUIC datagram should be suffixed with session ID.
// ExtractSessionID extracts the session ID and a slice with only the payload
func ExtractSessionID(b []byte) (uuid.UUID, []byte, error) {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net"
"time"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
@@ -239,8 +240,8 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
}, nil
}
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort)
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration) error {
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint)
if err != nil {
return err
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/rs/zerolog"
@@ -15,6 +16,10 @@ import (
"github.com/stretchr/testify/require"
)
const (
testCloseIdleAfterHint = time.Minute * 2
)
func TestConnectRequestData(t *testing.T) {
var tests = []struct {
name string
@@ -110,9 +115,10 @@ func TestRegisterUdpSession(t *testing.T) {
serverStream := mockRPCStream{serverReader, serverWriter}
rpcServer := mockRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
}
logger := zerolog.Nop()
sessionRegisteredChan := make(chan struct{})
@@ -131,10 +137,10 @@ func TestRegisterUdpSession(t *testing.T) {
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
assert.NoError(t, err)
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort))
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
// Different sessionID, the RPC server should reject the registraion
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort))
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID))
@@ -146,12 +152,13 @@ func TestRegisterUdpSession(t *testing.T) {
}
type mockRPCServer struct {
sessionID uuid.UUID
dstIP net.IP
dstPort uint16
sessionID uuid.UUID
dstIP net.IP
dstPort uint16
closeIdleAfter time.Duration
}
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
@@ -159,7 +166,10 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
return fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
}
if s.dstPort != dstPort {
return fmt.Errorf("expect session ID %d, got %d", s.dstPort, dstPort)
return fmt.Errorf("expect destination port %d, got %d", s.dstPort, dstPort)
}
if s.closeIdleAfter != closeIdleAfter {
return fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
}
return nil
}