TUN-5695: Define RPC method to update configuration

This commit is contained in:
cthuang
2022-02-02 12:27:49 +00:00
parent 0ab6867ae5
commit d07d24e5a2
8 changed files with 923 additions and 223 deletions

View File

@@ -125,7 +125,7 @@ func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, err
return nil, err
}
if signature != DataStreamProtocolSignature {
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
return nil, fmt.Errorf("wrong protocol signature %v", signature)
}
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
@@ -157,13 +157,13 @@ func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (
return &RPCServerStream{stream}, nil
}
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, configManager tunnelpogs.ConfigurationManager, logger *zerolog.Logger) error {
// RPC logs are very robust, create a new logger that only logs error to reduce noise
rpcLogger := logger.Level(zerolog.ErrorLevel)
rpcTransport := tunnelrpc.NewTransportLogger(&rpcLogger, rpc.StreamTransport(s))
defer rpcTransport.Close()
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
main := tunnelpogs.CloudflaredServer_ServerToClient(sessionManager, configManager)
rpcConn := rpc.NewConn(
rpcTransport,
rpc.MainInterface(main.Client),
@@ -223,7 +223,7 @@ func writeSignature(stream io.Writer, signature ProtocolSignature) error {
// RPCClientStream is a stream to call methods of SessionManager
type RPCClientStream struct {
client tunnelpogs.SessionManager_PogsClient
client tunnelpogs.CloudflaredServer_PogsClient
transport rpc.Transport
}
@@ -241,7 +241,7 @@ func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *
tunnelrpc.ConnLog(logger),
)
return &RPCClientStream{
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
client: tunnelpogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn),
transport: transport,
}, nil
}
@@ -258,6 +258,10 @@ func (rcs *RPCClientStream) UnregisterUdpSession(ctx context.Context, sessionID
return rcs.client.UnregisterUdpSession(ctx, sessionID, message)
}
func (rcs *RPCClientStream) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
return rcs.client.UpdateConfiguration(ctx, version, config)
}
func (rcs *RPCClientStream) Close() {
_ = rcs.client.Close()
_ = rcs.transport.Close()

View File

@@ -14,6 +14,8 @@ import (
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
const (
@@ -108,14 +110,10 @@ func TestConnectResponseMeta(t *testing.T) {
}
func TestRegisterUdpSession(t *testing.T) {
clientReader, serverWriter := io.Pipe()
serverReader, clientWriter := io.Pipe()
clientStream := mockRPCStream{clientReader, clientWriter}
serverStream := mockRPCStream{serverReader, serverWriter}
clientStream, serverStream := newMockRPCStreams()
unregisterMessage := "closed by eyeball"
rpcServer := mockRPCServer{
sessionRPCServer := mockSessionRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
@@ -129,7 +127,7 @@ func TestRegisterUdpSession(t *testing.T) {
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(rpcServer, &logger)
err = rpcServerStream.Serve(sessionRPCServer, nil, &logger)
assert.NoError(t, err)
serverStream.Close()
@@ -139,12 +137,12 @@ func TestRegisterUdpSession(t *testing.T) {
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
assert.NoError(t, err)
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
assert.NoError(t, rpcClientStream.RegisterUdpSession(context.Background(), sessionRPCServer.sessionID, sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
// Different sessionID, the RPC server should reject the registraion
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort, testCloseIdleAfterHint))
assert.Error(t, rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), sessionRPCServer.dstIP, sessionRPCServer.dstPort, testCloseIdleAfterHint))
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), rpcServer.sessionID, unregisterMessage))
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), sessionRPCServer.sessionID, unregisterMessage))
// Different sessionID, the RPC server should reject the unregistraion
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
@@ -153,7 +151,48 @@ func TestRegisterUdpSession(t *testing.T) {
<-sessionRegisteredChan
}
type mockRPCServer struct {
func TestManageConfiguration(t *testing.T) {
var (
version int32 = 168
config = []byte(t.Name())
)
clientStream, serverStream := newMockRPCStreams()
configRPCServer := mockConfigRPCServer{
version: version,
config: config,
}
logger := zerolog.Nop()
updatedChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(nil, configRPCServer, &logger)
assert.NoError(t, err)
serverStream.Close()
close(updatedChan)
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
rpcClientStream, err := NewRPCClientStream(ctx, clientStream, &logger)
assert.NoError(t, err)
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
assert.NoError(t, err)
require.Equal(t, version, result.LastAppliedVersion)
require.NoError(t, result.Err)
rpcClientStream.Close()
<-updatedChan
}
type mockSessionRPCServer struct {
sessionID uuid.UUID
dstIP net.IP
dstPort uint16
@@ -161,7 +200,7 @@ type mockRPCServer struct {
unregisterMessage string
}
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
@@ -177,7 +216,7 @@ func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UU
return nil
}
func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
@@ -187,11 +226,35 @@ func (s mockRPCServer) UnregisterUdpSession(ctx context.Context, sessionID uuid.
return nil
}
type mockConfigRPCServer struct {
version int32
config []byte
}
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) (*tunnelpogs.UpdateConfigurationResponse, error) {
if s.version != version {
return nil, fmt.Errorf("expect version %d, got %d", s.version, version)
}
if !bytes.Equal(s.config, config) {
return nil, fmt.Errorf("expect config %v, got %v", s.config, config)
}
return &tunnelpogs.UpdateConfigurationResponse{LastAppliedVersion: version}, nil
}
type mockRPCStream struct {
io.ReadCloser
io.WriteCloser
}
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
clientReader, serverWriter := io.Pipe()
serverReader, clientWriter := io.Pipe()
client = mockRPCStream{clientReader, clientWriter}
server = mockRPCStream{serverReader, serverWriter}
return
}
func (s mockRPCStream) Close() error {
_ = s.ReadCloser.Close()
_ = s.WriteCloser.Close()