mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:39:57 +00:00
TUN-5695: Define RPC method to update configuration
This commit is contained in:
@@ -125,7 +125,7 @@ func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, err
|
||||
return nil, err
|
||||
}
|
||||
if signature != DataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
|
||||
return nil, fmt.Errorf("wrong protocol signature %v", signature)
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
@@ -157,13 +157,13 @@ func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (
|
||||
return &RPCServerStream{stream}, nil
|
||||
}
|
||||
|
||||
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
|
||||
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, configManager tunnelpogs.ConfigurationManager, logger *zerolog.Logger) error {
|
||||
// RPC logs are very robust, create a new logger that only logs error to reduce noise
|
||||
rpcLogger := logger.Level(zerolog.ErrorLevel)
|
||||
rpcTransport := tunnelrpc.NewTransportLogger(&rpcLogger, rpc.StreamTransport(s))
|
||||
defer rpcTransport.Close()
|
||||
|
||||
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
|
||||
main := tunnelpogs.CloudflaredServer_ServerToClient(sessionManager, configManager)
|
||||
rpcConn := rpc.NewConn(
|
||||
rpcTransport,
|
||||
rpc.MainInterface(main.Client),
|
||||
@@ -223,7 +223,7 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
|
||||
|
||||
// RPCClientStream is a stream to call methods of SessionManager
|
||||
type RPCClientStream struct {
|
||||
client tunnelpogs.SessionManager_PogsClient
|
||||
client tunnelpogs.CloudflaredServer_PogsClient
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
@@ -241,7 +241,7 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
|
||||
tunnelrpc.ConnLog(logger),
|
||||
)
|
||||
return &RPCClientStream{
|
||||
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
||||
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
|
||||
transport: transport,
|
||||
}, nil
|
||||
}
|
||||
@@ -258,6 +258,10 @@ func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID
|
||||
return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
|
||||
return rcs.client.UpdateConfiguration(ctx, version, config)
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) Close() {
|
||||
_ = rcs.client.Close()
|
||||
_ = rcs.transport.Close()
|
||||
|
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -108,14 +110,10 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegisterUdpSession(t *testing.T) {
|
||||
clientReader, serverWriter := io.Pipe()
|
||||
serverReader, clientWriter := io.Pipe()
|
||||
|
||||
clientStream := mockRPCStream{clientReader, clientWriter}
|
||||
serverStream := mockRPCStream{serverReader, serverWriter}
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
|
||||
unregisterMessage := "closed by eyeball"
|
||||
rpcServer := mockRPCServer{
|
||||
sessionRPCServer := mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
@@ -129,7 +127,7 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(rpcServer, &logger)
|
||||
err = rpcServerStream.Serve(sessionRPCServer, nil, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
@@ -139,12 +137,12 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
|
||||
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
|
||||
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
|
||||
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID, unregisterMessage))
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage))
|
||||
|
||||
// Different sessionID, the RPC server should reject the unregistraion
|
||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
||||
@@ -153,7 +151,48 @@ func TestRegisterUdpSession(t *testing.T) {
|
||||
<-sessionRegisteredChan
|
||||
}
|
||||
|
||||
type mockRPCServer struct {
|
||||
func TestManageConfiguration(t *testing.T) {
|
||||
var (
|
||||
version int32 = 168
|
||||
config = []byte(t.Name())
|
||||
)
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
|
||||
configRPCServer := mockConfigRPCServer{
|
||||
version: version,
|
||||
config: config,
|
||||
}
|
||||
|
||||
logger := zerolog.Nop()
|
||||
updatedChan := make(chan struct{})
|
||||
go func() {
|
||||
protocol, err := DetermineProtocol(serverStream)
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(nil, configRPCServer, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(updatedChan)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
require.Equal(t, version, result.LastAppliedVersion)
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-updatedChan
|
||||
}
|
||||
|
||||
type mockSessionRPCServer struct {
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
@@ -161,7 +200,7 @@ type mockRPCServer struct {
|
||||
unregisterMessage string
|
||||
}
|
||||
|
||||
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
|
||||
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
@@ -177,7 +216,7 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
@@ -187,11 +226,35 @@ func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockConfigRPCServer struct {
|
||||
version int32
|
||||
config []byte
|
||||
}
|
||||
|
||||
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
|
||||
if s.version != version {
|
||||
return nil, fmt.Errorf("expect version %d, got %d", s.version, version)
|
||||
}
|
||||
if !bytes.Equal(s.config, config) {
|
||||
return nil, fmt.Errorf("expect config %v, got %v", s.config, config)
|
||||
}
|
||||
return &tunnelpogs.UpdateConfigurationResponse{LastAppliedVersion: version}, nil
|
||||
}
|
||||
|
||||
type mockRPCStream struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
|
||||
clientReader, serverWriter := io.Pipe()
|
||||
serverReader, clientWriter := io.Pipe()
|
||||
|
||||
client = mockRPCStream{clientReader, clientWriter}
|
||||
server = mockRPCStream{serverReader, serverWriter}
|
||||
return
|
||||
}
|
||||
|
||||
func (s mockRPCStream) Close() error {
|
||||
_ = s.ReadCloser.Close()
|
||||
_ = s.WriteCloser.Close()
|
||||
|
Reference in New Issue
Block a user