mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-05-22 19:26:35 +00:00

Combines the tunnelrpc and quic/schema capnp files into the same module. To help reduce future issues with capnp ID generation, the capnp files now specify explicit capnp IDs, taken from the existing struct IDs that were generated in the Go files. Reduces the overall interface that the capnp methods expose to the rest of the code by providing a single interface that handles QUIC protocol selection. Introduces a new `rpc-timeout` config that allows all SessionManager and ConfigurationManager RPC requests to have a timeout. The timeout is set to 5 seconds, as none of these manager operations should take long to complete. Removes the RPC-specific logger, as it never provided good debugging value: the RPC method names were not visible in the logs.
299 lines
8.3 KiB
Go
299 lines
8.3 KiB
Go
package quic
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
// testCloseIdleAfterHint is the idle timeout the session tests pass to
// RegisterUdpSession and configure as the mock server's expected value.
const testCloseIdleAfterHint = 2 * time.Minute
|
|
|
|
func TestConnectRequestData(t *testing.T) {
|
|
var tests = []struct {
|
|
name string
|
|
hostname string
|
|
connectionType pogs.ConnectionType
|
|
metadata []pogs.Metadata
|
|
}{
|
|
{
|
|
name: "Signature verified and request metadata is unmarshaled and read correctly",
|
|
hostname: "tunnel.com",
|
|
connectionType: pogs.ConnectionTypeHTTP,
|
|
metadata: []pogs.Metadata{
|
|
{
|
|
Key: "key",
|
|
Val: "1234",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
b := &bytes.Buffer{}
|
|
reqClientStream := RequestClientStream{noopCloser{b}}
|
|
err := reqClientStream.WriteConnectRequestData(test.hostname, test.connectionType, test.metadata...)
|
|
require.NoError(t, err)
|
|
protocol, err := determineProtocol(b)
|
|
require.NoError(t, err)
|
|
require.Equal(t, dataStreamProtocolSignature, protocol)
|
|
reqServerStream := RequestServerStream{&noopCloser{b}}
|
|
|
|
reqMeta, err := reqServerStream.ReadConnectRequestData()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, test.metadata, reqMeta.Metadata)
|
|
assert.Equal(t, test.hostname, reqMeta.Dest)
|
|
assert.Equal(t, test.connectionType, reqMeta.Type)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectResponseMeta(t *testing.T) {
|
|
var tests = []struct {
|
|
name string
|
|
err error
|
|
metadata []pogs.Metadata
|
|
}{
|
|
{
|
|
name: "Signature verified and response metadata is unmarshaled and read correctly",
|
|
metadata: []pogs.Metadata{
|
|
{
|
|
Key: "key",
|
|
Val: "1234",
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "If error is not empty, other fields should be blank",
|
|
err: errors.New("something happened"),
|
|
metadata: []pogs.Metadata{
|
|
{
|
|
Key: "key",
|
|
Val: "1234",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
b := &bytes.Buffer{}
|
|
reqServerStream := RequestServerStream{noopCloser{b}}
|
|
err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
|
|
require.NoError(t, err)
|
|
|
|
reqClientStream := RequestClientStream{noopCloser{b}}
|
|
respMeta, err := reqClientStream.ReadConnectResponseData()
|
|
require.NoError(t, err)
|
|
|
|
if respMeta.Error == "" {
|
|
assert.Equal(t, test.metadata, respMeta.Metadata)
|
|
} else {
|
|
assert.Equal(t, 0, len(respMeta.Metadata))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRegisterUdpSession(t *testing.T) {
|
|
unregisterMessage := "closed by eyeball"
|
|
|
|
var tests = []struct {
|
|
name string
|
|
sessionRPCServer mockSessionRPCServer
|
|
}{
|
|
{
|
|
name: "RegisterUdpSession (no trace context)",
|
|
sessionRPCServer: mockSessionRPCServer{
|
|
sessionID: uuid.New(),
|
|
dstIP: net.IP{172, 16, 0, 1},
|
|
dstPort: 8000,
|
|
closeIdleAfter: testCloseIdleAfterHint,
|
|
unregisterMessage: unregisterMessage,
|
|
traceContext: "",
|
|
},
|
|
},
|
|
{
|
|
name: "RegisterUdpSession (with trace context)",
|
|
sessionRPCServer: mockSessionRPCServer{
|
|
sessionID: uuid.New(),
|
|
dstIP: net.IP{172, 16, 0, 1},
|
|
dstPort: 8000,
|
|
closeIdleAfter: testCloseIdleAfterHint,
|
|
unregisterMessage: unregisterMessage,
|
|
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
clientStream, serverStream := newMockRPCStreams()
|
|
sessionRegisteredChan := make(chan struct{})
|
|
go func() {
|
|
ss := NewCloudflaredServer(nil, test.sessionRPCServer, nil, 10*time.Second)
|
|
err := ss.Serve(context.Background(), serverStream)
|
|
assert.NoError(t, err)
|
|
|
|
serverStream.Close()
|
|
close(sessionRegisteredChan)
|
|
}()
|
|
|
|
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
|
|
assert.NoError(t, err)
|
|
|
|
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
|
assert.NoError(t, err)
|
|
assert.NoError(t, reg.Err)
|
|
|
|
// Different sessionID, the RPC server should reject the registraion
|
|
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
|
assert.NoError(t, err)
|
|
assert.Error(t, reg.Err)
|
|
|
|
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
|
|
|
// Different sessionID, the RPC server should reject the unregistraion
|
|
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
|
|
|
rpcClientStream.Close()
|
|
<-sessionRegisteredChan
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestManageConfiguration(t *testing.T) {
|
|
var (
|
|
version int32 = 168
|
|
config = []byte(t.Name())
|
|
)
|
|
clientStream, serverStream := newMockRPCStreams()
|
|
|
|
configRPCServer := mockConfigRPCServer{
|
|
version: version,
|
|
config: config,
|
|
}
|
|
|
|
updatedChan := make(chan struct{})
|
|
go func() {
|
|
server := NewCloudflaredServer(nil, nil, configRPCServer, 10*time.Second)
|
|
err := server.Serve(context.Background(), serverStream)
|
|
assert.NoError(t, err)
|
|
|
|
serverStream.Close()
|
|
close(updatedChan)
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
|
|
assert.NoError(t, err)
|
|
|
|
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
|
|
assert.NoError(t, err)
|
|
|
|
require.Equal(t, version, result.LastAppliedVersion)
|
|
require.NoError(t, result.Err)
|
|
|
|
rpcClientStream.Close()
|
|
<-updatedChan
|
|
}
|
|
|
|
type mockSessionRPCServer struct {
|
|
sessionID uuid.UUID
|
|
dstIP net.IP
|
|
dstPort uint16
|
|
closeIdleAfter time.Duration
|
|
unregisterMessage string
|
|
traceContext string
|
|
}
|
|
|
|
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
|
|
if s.sessionID != sessionID {
|
|
return nil, fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
|
}
|
|
if !s.dstIP.Equal(dstIP) {
|
|
return nil, fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
|
|
}
|
|
if s.dstPort != dstPort {
|
|
return nil, fmt.Errorf("expect destination port %d, got %d", s.dstPort, dstPort)
|
|
}
|
|
if s.closeIdleAfter != closeIdleAfter {
|
|
return nil, fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
|
|
}
|
|
if s.traceContext != traceContext {
|
|
return nil, fmt.Errorf("expect traceContext %s, got %s", s.traceContext, traceContext)
|
|
}
|
|
return &pogs.RegisterUdpSessionResponse{}, nil
|
|
}
|
|
|
|
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
|
|
if s.sessionID != sessionID {
|
|
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
|
}
|
|
if s.unregisterMessage != message {
|
|
return fmt.Errorf("expect unregister message %s, got %s", s.unregisterMessage, message)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// mockConfigRPCServer is a configuration-manager stub. Its
// UpdateConfiguration method succeeds only when called with exactly this
// version and config payload.
type mockConfigRPCServer struct {
	version int32  // expected configuration version
	config  []byte // expected raw configuration bytes
}
|
|
|
|
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) *pogs.UpdateConfigurationResponse {
|
|
if s.version != version {
|
|
return &pogs.UpdateConfigurationResponse{
|
|
Err: fmt.Errorf("expect version %d, got %d", s.version, version),
|
|
}
|
|
}
|
|
if !bytes.Equal(s.config, config) {
|
|
return &pogs.UpdateConfigurationResponse{
|
|
Err: fmt.Errorf("expect config %v, got %v", s.config, config),
|
|
}
|
|
}
|
|
return &pogs.UpdateConfigurationResponse{LastAppliedVersion: version}
|
|
}
|
|
|
|
// mockRPCStream glues an independent reader and writer into a single
// bidirectional stream (newMockRPCStreams builds it from two io.Pipe halves).
// Both embedded interfaces declare Close, so the promoted Close is ambiguous
// and the type provides its own.
type mockRPCStream struct {
	io.ReadCloser  // incoming direction
	io.WriteCloser // outgoing direction
}
|
|
|
|
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
|
|
clientReader, serverWriter := io.Pipe()
|
|
serverReader, clientWriter := io.Pipe()
|
|
|
|
client = mockRPCStream{clientReader, clientWriter}
|
|
server = mockRPCStream{serverReader, serverWriter}
|
|
return
|
|
}
|
|
|
|
func (s mockRPCStream) Close() error {
|
|
_ = s.ReadCloser.Close()
|
|
_ = s.WriteCloser.Close()
|
|
return nil
|
|
}
|
|
|
|
// noopCloser adapts any io.ReadWriter (e.g. a *bytes.Buffer) into an
// io.ReadWriteCloser whose Close does nothing.
type noopCloser struct {
	io.ReadWriter // wrapped stream; all reads and writes go straight to it
}
|
|
|
|
func (noopCloser) Close() error {
|
|
return nil
|
|
}
|