TUN-7480: Added a timeout for unregisterUDP.

I deliberately kept this as an unregistertimeout because that was the
intent. In the future we could change this to a UDPConnConfig if we want
to pass multiple values here.

The idea of this PR is simply to add a configurable unregister UDP
timeout.
This commit is contained in:
Sudarsan Reddy
2023-06-19 17:03:11 +01:00
parent a3bcf25fae
commit 1abd22ef0a
7 changed files with 96 additions and 15 deletions

View File

@@ -229,9 +229,12 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
type RPCClientStream struct {
client tunnelpogs.CloudflaredServer_PogsClient
transport rpc.Transport
// Time we wait for the server to respond to a request before we close the connection.
rpcUnregisterUDPSessionDeadline time.Duration
}
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, rpcUnregisterUDPSessionDeadline time.Duration, logger *zerolog.Logger) (*RPCClientStream, error) {
n, err := stream.Write(RPCStreamProtocolSignature[:])
if err != nil {
return nil, err
@@ -245,8 +248,9 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
tunnelrpc.ConnLog(logger),
)
return &RPCClientStream{
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport,
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport,
rpcUnregisterUDPSessionDeadline: rpcUnregisterUDPSessionDeadline,
}, nil
}
@@ -255,6 +259,8 @@ func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uu
}
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
ctx, cancel := context.WithTimeout(ctx, rcs.rpcUnregisterUDPSessionDeadline)
defer cancel()
return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
}

View File

@@ -109,6 +109,63 @@ func TestConnectResponseMeta(t *testing.T) {
}
}
func TestUnregisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
timeout time.Duration
}{
{
name: "UnregisterUdpSessionTimesout if the RPC server does not respond",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
// very very low value so we trigger the timeout every time.
timeout: time.Nanosecond * 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, test.timeout, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestRegisterUdpSession(t *testing.T) {
unregisterMessage := "closed by eyeball"
@@ -157,7 +214,7 @@ func TestRegisterUdpSession(t *testing.T) {
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, 5*time.Second, &logger)
assert.NoError(t, err)
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
@@ -208,7 +265,7 @@ func TestManageConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger)
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, 5*time.Second, &logger)
assert.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)