mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 17:29:58 +00:00
TUN-6688: Update RegisterUdpSession capnproto to include trace context
This commit is contained in:
@@ -247,12 +247,8 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration) error {
|
||||
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return resp.Err
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration, traceContext string) (*tunnelpogs.RegisterUdpSessionResponse, error) {
|
||||
return rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint, traceContext)
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
|
@@ -110,45 +110,74 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegisterUdpSession(t *testing.T) {
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
|
||||
unregisterMessage := "closed by eyeball"
|
||||
sessionRPCServer := mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
unregisterMessage: unregisterMessage,
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
sessionRPCServer mockSessionRPCServer
|
||||
}{
|
||||
{
|
||||
name: "RegisterUdpSession (no trace context)",
|
||||
sessionRPCServer: mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
unregisterMessage: unregisterMessage,
|
||||
traceContext: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RegisterUdpSession (with trace context)",
|
||||
sessionRPCServer: mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
unregisterMessage: unregisterMessage,
|
||||
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
|
||||
},
|
||||
},
|
||||
}
|
||||
logger := zerolog.Nop()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
protocol, err := DetermineProtocol(serverStream)
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(sessionRPCServer, nil, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(sessionRegisteredChan)
|
||||
}()
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
logger := zerolog.Nop()
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
protocol, err := DetermineProtocol(serverStream)
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(test.sessionRPCServer, nil, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
serverStream.Close()
|
||||
close(sessionRegisteredChan)
|
||||
}()
|
||||
|
||||
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
|
||||
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, reg.Err)
|
||||
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage))
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, reg.Err)
|
||||
|
||||
// Different sessionID, the RPC server should reject the unregistraion
|
||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-sessionRegisteredChan
|
||||
// Different sessionID, the RPC server should reject the unregistraion
|
||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-sessionRegisteredChan
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManageConfiguration(t *testing.T) {
|
||||
@@ -198,9 +227,10 @@ type mockSessionRPCServer struct {
|
||||
dstPort uint16
|
||||
closeIdleAfter time.Duration
|
||||
unregisterMessage string
|
||||
traceContext string
|
||||
}
|
||||
|
||||
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
|
||||
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration, traceContext string) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
@@ -213,6 +243,9 @@ func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uu
|
||||
if s.closeIdleAfter != closeIdleAfter {
|
||||
return fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
|
||||
}
|
||||
if s.traceContext != traceContext {
|
||||
return fmt.Errorf("expect traceContext %s, got %s", s.traceContext, traceContext)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user