TUN-6688: Update RegisterUdpSession capnproto to include trace context

This commit is contained in:
Devin Carr
2022-09-07 15:06:06 -07:00
parent 11cbff4ff7
commit e380333520
8 changed files with 361 additions and 279 deletions

View File

@@ -247,12 +247,8 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
}, nil
}
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration) error {
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint)
if err != nil {
return err
}
return resp.Err
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
return rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint, traceContext)
}
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {

View File

@@ -110,45 +110,74 @@ func TestConnectResponseMeta(t *testing.T) {
}
func TestRegisterUdpSession(t *testing.T) {
clientStream, serverStream := newMockRPCStreams()
unregisterMessage := "closed by eyeball"
sessionRPCServer := mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
var tests = []struct {
name string
sessionRPCServer mockSessionRPCServer
}{
{
name: "RegisterUdpSession (no trace context)",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "",
},
},
{
name: "RegisterUdpSession (with trace context)",
sessionRPCServer: mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
closeIdleAfter: testCloseIdleAfterHint,
unregisterMessage: unregisterMessage,
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
},
},
}
logger := zerolog.Nop()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
logger := zerolog.Nop()
clientStream, serverStream := newMockRPCStreams()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
assert.NoError(t, err)
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
assert.NoError(t, err)
// Different sessionID, the RPC server should reject the registraion
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.NoError(t, reg.Err)
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage))
// Different sessionID, the RPC server should reject the registraion
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
assert.NoError(t, err)
assert.Error(t, reg.Err)
// Different sessionID, the RPC server should reject the unregistraion
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
// Different sessionID, the RPC server should reject the unregistraion
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
rpcClientStream.Close()
<-sessionRegisteredChan
})
}
}
func TestManageConfiguration(t *testing.T) {
@@ -198,9 +227,10 @@ type mockSessionRPCServer struct {
dstPort uint16
closeIdleAfter time.Duration
unregisterMessage string
traceContext string
}
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration, traceContext string) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
@@ -213,6 +243,9 @@ func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uu
if s.closeIdleAfter != closeIdleAfter {
return fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
}
if s.traceContext != traceContext {
return fmt.Errorf("expect traceContext %s, got %s", s.traceContext, traceContext)
}
return nil
}