mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 21:09:58 +00:00
TUN-8415: Refactor capnp rpc into a single module
Combines the tunnelrpc and quic/schema capnp files into the same module. To help reduce future issues with capnp id generation, capnpids are provided in the capnp files from the existing capnp struct ids generated in the go files. Reduces the overall interface of the Capnp methods to the rest of the code by providing an interface that will handle the quic protocol selection. Introduces a new `rpc-timeout` config that will allow all of the SessionManager and ConfigurationManager RPC requests to have a timeout. The timeout for these values is set to 5 seconds as non of these operations for the managers should take a long time to complete. Removed the RPC-specific logger as it never provided good debugging value as the RPC method names were not visible in the logs.
This commit is contained in:
63
tunnelrpc/quic/cloudflared_client.go
Normal file
63
tunnelrpc/quic/cloudflared_client.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// CloudflaredClient calls capnp rpc methods of SessionManager and ConfigurationManager.
|
||||
type CloudflaredClient struct {
|
||||
client pogs.CloudflaredServer_PogsClient
|
||||
transport rpc.Transport
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewCloudflaredClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) (*CloudflaredClient, error) {
|
||||
n, err := stream.Write(rpcStreamProtocolSignature[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != len(rpcStreamProtocolSignature) {
|
||||
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(rpcStreamProtocolSignature), n)
|
||||
}
|
||||
transport := rpc.StreamTransport(stream)
|
||||
conn := rpc.NewConn(transport)
|
||||
client := pogs.NewCloudflaredServer_PogsClient(conn.Bootstrap(ctx), conn)
|
||||
return &CloudflaredClient{
|
||||
client: client,
|
||||
transport: transport,
|
||||
requestTimeout: requestTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *CloudflaredClient) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
return c.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint, traceContext)
|
||||
}
|
||||
|
||||
func (c *CloudflaredClient) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
return c.client.UnregisterUdpSession(ctx, sessionID, message)
|
||||
}
|
||||
|
||||
func (c *CloudflaredClient) UpdateConfiguration(ctx context.Context, version int32, config []byte) (*pogs.UpdateConfigurationResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
return c.client.UpdateConfiguration(ctx, version, config)
|
||||
}
|
||||
|
||||
func (c *CloudflaredClient) Close() {
|
||||
_ = c.client.Close()
|
||||
_ = c.transport.Close()
|
||||
}
|
69
tunnelrpc/quic/cloudflared_server.go
Normal file
69
tunnelrpc/quic/cloudflared_server.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// HandleRequestFunc wraps the proxied request from the upstream and also provides methods on the stream to
|
||||
// handle the response back.
|
||||
type HandleRequestFunc = func(ctx context.Context, stream *RequestServerStream) error
|
||||
|
||||
// CloudflaredServer provides a handler interface for a client to provide methods to handle the different types of
|
||||
// requests that can be communicated by the stream.
|
||||
type CloudflaredServer struct {
|
||||
handleRequest HandleRequestFunc
|
||||
sessionManager pogs.SessionManager
|
||||
configManager pogs.ConfigurationManager
|
||||
responseTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewCloudflaredServer(handleRequest HandleRequestFunc, sessionManager pogs.SessionManager, configManager pogs.ConfigurationManager, responseTimeout time.Duration) *CloudflaredServer {
|
||||
return &CloudflaredServer{
|
||||
handleRequest: handleRequest,
|
||||
sessionManager: sessionManager,
|
||||
configManager: configManager,
|
||||
responseTimeout: responseTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Serve executes the defined handlers in ServerStream on the provided stream if it is a proper RPC stream with the
|
||||
// correct preamble protocol signature.
|
||||
func (s *CloudflaredServer) Serve(ctx context.Context, stream io.ReadWriteCloser) error {
|
||||
signature, err := determineProtocol(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch signature {
|
||||
case dataStreamProtocolSignature:
|
||||
return s.handleRequest(ctx, &RequestServerStream{stream})
|
||||
case rpcStreamProtocolSignature:
|
||||
return s.handleRPC(ctx, stream)
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol %v", signature)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CloudflaredServer) handleRPC(ctx context.Context, stream io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, s.responseTimeout)
|
||||
defer cancel()
|
||||
transport := rpc.StreamTransport(stream)
|
||||
defer transport.Close()
|
||||
|
||||
main := pogs.CloudflaredServer_ServerToClient(s.sessionManager, s.configManager)
|
||||
rpcConn := rpc.NewConn(transport, rpc.MainInterface(main.Client))
|
||||
defer rpcConn.Close()
|
||||
|
||||
// We ignore the errors here because if cloudflared fails to handle a request, we will just move on.
|
||||
select {
|
||||
case <-rpcConn.Done():
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
}
|
78
tunnelrpc/quic/protocol.go
Normal file
78
tunnelrpc/quic/protocol.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// protocolSignature defines the first 6 bytes of the stream, which is used to distinguish the type of stream. It
|
||||
// ensures whoever performs a handshake does not write data before writing the metadata.
|
||||
type protocolSignature [6]byte
|
||||
|
||||
var (
|
||||
// dataStreamProtocolSignature is a custom protocol signature for data stream
|
||||
dataStreamProtocolSignature = protocolSignature{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
|
||||
// rpcStreamProtocolSignature is a custom protocol signature for RPC stream
|
||||
rpcStreamProtocolSignature = protocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
||||
|
||||
errDataStreamNotSupported = fmt.Errorf("data protocol not supported")
|
||||
errRPCStreamNotSupported = fmt.Errorf("rpc protocol not supported")
|
||||
)
|
||||
|
||||
type protocolVersion string
|
||||
|
||||
const (
|
||||
protocolV1 protocolVersion = "01"
|
||||
|
||||
protocolVersionLength = 2
|
||||
)
|
||||
|
||||
// determineProtocol reads the first 6 bytes from the stream to determine which protocol is spoken by the client.
|
||||
// The protocols are magic byte arrays understood by both sides of the stream.
|
||||
func determineProtocol(stream io.Reader) (protocolSignature, error) {
|
||||
signature, err := readSignature(stream)
|
||||
if err != nil {
|
||||
return protocolSignature{}, err
|
||||
}
|
||||
switch signature {
|
||||
case dataStreamProtocolSignature:
|
||||
return dataStreamProtocolSignature, nil
|
||||
case rpcStreamProtocolSignature:
|
||||
return rpcStreamProtocolSignature, nil
|
||||
default:
|
||||
return protocolSignature{}, fmt.Errorf("unknown signature %v", signature)
|
||||
}
|
||||
}
|
||||
|
||||
func writeDataStreamPreamble(stream io.Writer) error {
|
||||
if err := writeSignature(stream, dataStreamProtocolSignature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeVersion(stream)
|
||||
}
|
||||
|
||||
func writeVersion(stream io.Writer) error {
|
||||
_, err := stream.Write([]byte(protocolV1)[:protocolVersionLength])
|
||||
return err
|
||||
}
|
||||
|
||||
func readVersion(stream io.Reader) (string, error) {
|
||||
version := make([]byte, protocolVersionLength)
|
||||
_, err := stream.Read(version)
|
||||
return string(version), err
|
||||
}
|
||||
|
||||
func readSignature(stream io.Reader) (protocolSignature, error) {
|
||||
var signature protocolSignature
|
||||
if _, err := io.ReadFull(stream, signature[:]); err != nil {
|
||||
return protocolSignature{}, err
|
||||
}
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func writeSignature(stream io.Writer, signature protocolSignature) error {
|
||||
_, err := stream.Write(signature[:])
|
||||
return err
|
||||
}
|
61
tunnelrpc/quic/request_client_stream.go
Normal file
61
tunnelrpc/quic/request_client_stream.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// RequestClientStream is a stream to provide requests to the server. This operation is typically driven by the edge service.
|
||||
type RequestClientStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// WriteConnectRequestData writes requestMeta to a stream.
|
||||
func (rcs *RequestClientStream) WriteConnectRequestData(dest string, connectionType pogs.ConnectionType, metadata ...pogs.Metadata) error {
|
||||
connectRequest := &pogs.ConnectRequest{
|
||||
Dest: dest,
|
||||
Type: connectionType,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
msg, err := connectRequest.ToPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeDataStreamPreamble(rcs); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(rcs).Encode(msg)
|
||||
}
|
||||
|
||||
// ReadConnectResponseData reads the response from the rpc stream to a ConnectResponse.
|
||||
func (rcs *RequestClientStream) ReadConnectResponseData() (*pogs.ConnectResponse, error) {
|
||||
signature, err := determineProtocol(rcs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if signature != dataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("wrong protocol signature %v", signature)
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(rcs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(rcs).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &pogs.ConnectResponse{}
|
||||
if err := r.FromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
57
tunnelrpc/quic/request_server_stream.go
Normal file
57
tunnelrpc/quic/request_server_stream.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// RequestServerStream is a stream to serve requests
|
||||
type RequestServerStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// ReadConnectRequestData reads the handshake data from a QUIC stream.
|
||||
func (rss *RequestServerStream) ReadConnectRequestData() (*pogs.ConnectRequest, error) {
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(rss); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(rss).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &pogs.ConnectRequest{}
|
||||
if err := r.FromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteConnectResponseData writes response to a QUIC stream.
|
||||
func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata ...pogs.Metadata) error {
|
||||
var connectResponse *pogs.ConnectResponse
|
||||
if respErr != nil {
|
||||
connectResponse = &pogs.ConnectResponse{
|
||||
Error: respErr.Error(),
|
||||
}
|
||||
} else {
|
||||
connectResponse = &pogs.ConnectResponse{
|
||||
Metadata: metadata,
|
||||
}
|
||||
}
|
||||
|
||||
msg, err := connectResponse.ToPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeDataStreamPreamble(rss); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(rss).Encode(msg)
|
||||
}
|
298
tunnelrpc/quic/request_server_stream_test.go
Normal file
298
tunnelrpc/quic/request_server_stream_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
const (
|
||||
testCloseIdleAfterHint = time.Minute * 2
|
||||
)
|
||||
|
||||
func TestConnectRequestData(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
hostname string
|
||||
connectionType pogs.ConnectionType
|
||||
metadata []pogs.Metadata
|
||||
}{
|
||||
{
|
||||
name: "Signature verified and request metadata is unmarshaled and read correctly",
|
||||
hostname: "tunnel.com",
|
||||
connectionType: pogs.ConnectionTypeHTTP,
|
||||
metadata: []pogs.Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
err := reqClientStream.WriteConnectRequestData(test.hostname, test.connectionType, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
protocol, err := determineProtocol(b)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, dataStreamProtocolSignature, protocol)
|
||||
reqServerStream := RequestServerStream{&noopCloser{b}}
|
||||
|
||||
reqMeta, err := reqServerStream.ReadConnectRequestData()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.metadata, reqMeta.Metadata)
|
||||
assert.Equal(t, test.hostname, reqMeta.Dest)
|
||||
assert.Equal(t, test.connectionType, reqMeta.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectResponseMeta(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
err error
|
||||
metadata []pogs.Metadata
|
||||
}{
|
||||
{
|
||||
name: "Signature verified and response metadata is unmarshaled and read correctly",
|
||||
metadata: []pogs.Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "If error is not empty, other fields should be blank",
|
||||
err: errors.New("something happened"),
|
||||
metadata: []pogs.Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
reqServerStream := RequestServerStream{noopCloser{b}}
|
||||
err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
respMeta, err := reqClientStream.ReadConnectResponseData()
|
||||
require.NoError(t, err)
|
||||
|
||||
if respMeta.Error == "" {
|
||||
assert.Equal(t, test.metadata, respMeta.Metadata)
|
||||
} else {
|
||||
assert.Equal(t, 0, len(respMeta.Metadata))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterUdpSession(t *testing.T) {
|
||||
unregisterMessage := "closed by eyeball"
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
sessionRPCServer mockSessionRPCServer
|
||||
}{
|
||||
{
|
||||
name: "RegisterUdpSession (no trace context)",
|
||||
sessionRPCServer: mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
unregisterMessage: unregisterMessage,
|
||||
traceContext: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RegisterUdpSession (with trace context)",
|
||||
sessionRPCServer: mockSessionRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
closeIdleAfter: testCloseIdleAfterHint,
|
||||
unregisterMessage: unregisterMessage,
|
||||
traceContext: "1241ce3ecdefc68854e8514e69ba42ca:b38f1bf5eae406f3:0:1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
ss := NewCloudflaredServer(nil, test.sessionRPCServer, nil, 10*time.Second)
|
||||
err := ss.Serve(context.Background(), serverStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(sessionRegisteredChan)
|
||||
}()
|
||||
|
||||
rpcClientStream, err := NewCloudflaredClient(context.Background(), clientStream, 5*time.Second)
|
||||
assert.NoError(t, err)
|
||||
|
||||
reg, err := rpcClientStream.RegisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||
assert.NoError(t, err)
|
||||
assert.NoError(t, reg.Err)
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
reg, err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), test.sessionRPCServer.dstIP, test.sessionRPCServer.dstPort, testCloseIdleAfterHint, test.sessionRPCServer.traceContext)
|
||||
assert.NoError(t, err)
|
||||
assert.Error(t, reg.Err)
|
||||
|
||||
assert.NoError(t, rpcClientStream.UnregisterUdpSession(context.Background(), test.sessionRPCServer.sessionID, unregisterMessage))
|
||||
|
||||
// Different sessionID, the RPC server should reject the unregistraion
|
||||
assert.Error(t, rpcClientStream.UnregisterUdpSession(context.Background(), uuid.New(), unregisterMessage))
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-sessionRegisteredChan
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManageConfiguration(t *testing.T) {
|
||||
var (
|
||||
version int32 = 168
|
||||
config = []byte(t.Name())
|
||||
)
|
||||
clientStream, serverStream := newMockRPCStreams()
|
||||
|
||||
configRPCServer := mockConfigRPCServer{
|
||||
version: version,
|
||||
config: config,
|
||||
}
|
||||
|
||||
updatedChan := make(chan struct{})
|
||||
go func() {
|
||||
server := NewCloudflaredServer(nil, nil, configRPCServer, 10*time.Second)
|
||||
err := server.Serve(context.Background(), serverStream)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(updatedChan)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
rpcClientStream, err := NewCloudflaredClient(ctx, clientStream, 5*time.Second)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := rpcClientStream.UpdateConfiguration(ctx, version, config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
require.Equal(t, version, result.LastAppliedVersion)
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-updatedChan
|
||||
}
|
||||
|
||||
type mockSessionRPCServer struct {
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
closeIdleAfter time.Duration
|
||||
unregisterMessage string
|
||||
traceContext string
|
||||
}
|
||||
|
||||
func (s mockSessionRPCServer) RegisterUdpSession(_ context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfter time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
|
||||
if s.sessionID != sessionID {
|
||||
return nil, fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
if !s.dstIP.Equal(dstIP) {
|
||||
return nil, fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
|
||||
}
|
||||
if s.dstPort != dstPort {
|
||||
return nil, fmt.Errorf("expect destination port %d, got %d", s.dstPort, dstPort)
|
||||
}
|
||||
if s.closeIdleAfter != closeIdleAfter {
|
||||
return nil, fmt.Errorf("expect closeIdleAfter %d, got %d", s.closeIdleAfter, closeIdleAfter)
|
||||
}
|
||||
if s.traceContext != traceContext {
|
||||
return nil, fmt.Errorf("expect traceContext %s, got %s", s.traceContext, traceContext)
|
||||
}
|
||||
return &pogs.RegisterUdpSessionResponse{}, nil
|
||||
}
|
||||
|
||||
func (s mockSessionRPCServer) UnregisterUdpSession(_ context.Context, sessionID uuid.UUID, message string) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
if s.unregisterMessage != message {
|
||||
return fmt.Errorf("expect unregister message %s, got %s", s.unregisterMessage, message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockConfigRPCServer struct {
|
||||
version int32
|
||||
config []byte
|
||||
}
|
||||
|
||||
func (s mockConfigRPCServer) UpdateConfiguration(_ context.Context, version int32, config []byte) *pogs.UpdateConfigurationResponse {
|
||||
if s.version != version {
|
||||
return &pogs.UpdateConfigurationResponse{
|
||||
Err: fmt.Errorf("expect version %d, got %d", s.version, version),
|
||||
}
|
||||
}
|
||||
if !bytes.Equal(s.config, config) {
|
||||
return &pogs.UpdateConfigurationResponse{
|
||||
Err: fmt.Errorf("expect config %v, got %v", s.config, config),
|
||||
}
|
||||
}
|
||||
return &pogs.UpdateConfigurationResponse{LastAppliedVersion: version}
|
||||
}
|
||||
|
||||
type mockRPCStream struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func newMockRPCStreams() (client io.ReadWriteCloser, server io.ReadWriteCloser) {
|
||||
clientReader, serverWriter := io.Pipe()
|
||||
serverReader, clientWriter := io.Pipe()
|
||||
|
||||
client = mockRPCStream{clientReader, clientWriter}
|
||||
server = mockRPCStream{serverReader, serverWriter}
|
||||
return
|
||||
}
|
||||
|
||||
func (s mockRPCStream) Close() error {
|
||||
_ = s.ReadCloser.Close()
|
||||
_ = s.WriteCloser.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopCloser struct {
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
func (noopCloser) Close() error {
|
||||
return nil
|
||||
}
|
55
tunnelrpc/quic/session_client.go
Normal file
55
tunnelrpc/quic/session_client.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// SessionClient calls capnp rpc methods of SessionManager.
|
||||
type SessionClient struct {
|
||||
client pogs.SessionManager_PogsClient
|
||||
transport rpc.Transport
|
||||
requestTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewSessionClient(ctx context.Context, stream io.ReadWriteCloser, requestTimeout time.Duration) (*SessionClient, error) {
|
||||
n, err := stream.Write(rpcStreamProtocolSignature[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != len(rpcStreamProtocolSignature) {
|
||||
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(rpcStreamProtocolSignature), n)
|
||||
}
|
||||
transport := rpc.StreamTransport(stream)
|
||||
conn := rpc.NewConn(transport)
|
||||
return &SessionClient{
|
||||
client: pogs.NewSessionManager_PogsClient(conn.Bootstrap(ctx), conn),
|
||||
transport: transport,
|
||||
requestTimeout: requestTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SessionClient) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeIdleAfterHint time.Duration, traceContext string) (*pogs.RegisterUdpSessionResponse, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
return c.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort, closeIdleAfterHint, traceContext)
|
||||
}
|
||||
|
||||
func (c *SessionClient) UnregisterUdpSession(ctx context.Context, sessionID uuid.UUID, message string) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, c.requestTimeout)
|
||||
defer cancel()
|
||||
return c.client.UnregisterUdpSession(ctx, sessionID, message)
|
||||
}
|
||||
|
||||
func (c *SessionClient) Close() {
|
||||
_ = c.client.Close()
|
||||
_ = c.transport.Close()
|
||||
}
|
59
tunnelrpc/quic/session_server.go
Normal file
59
tunnelrpc/quic/session_server.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// SessionManagerServer handles streams with the SessionManager RPCs.
|
||||
type SessionManagerServer struct {
|
||||
sessionManager pogs.SessionManager
|
||||
responseTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewSessionManagerServer(sessionManager pogs.SessionManager, responseTimeout time.Duration) *SessionManagerServer {
|
||||
return &SessionManagerServer{
|
||||
sessionManager: sessionManager,
|
||||
responseTimeout: responseTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SessionManagerServer) Serve(ctx context.Context, stream io.ReadWriteCloser) error {
|
||||
signature, err := determineProtocol(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch signature {
|
||||
case rpcStreamProtocolSignature:
|
||||
break
|
||||
case dataStreamProtocolSignature:
|
||||
return errDataStreamNotSupported
|
||||
default:
|
||||
return fmt.Errorf("unknown protocol %v", signature)
|
||||
}
|
||||
|
||||
// Every new quic.Stream request aligns to a new RPC request, this is why there is a timeout for the server-side
|
||||
// of the RPC request.
|
||||
ctx, cancel := context.WithTimeout(ctx, s.responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
transport := rpc.StreamTransport(stream)
|
||||
defer transport.Close()
|
||||
|
||||
main := pogs.SessionManager_ServerToClient(s.sessionManager)
|
||||
rpcConn := rpc.NewConn(transport, rpc.MainInterface(main.Client))
|
||||
defer rpcConn.Close()
|
||||
|
||||
select {
|
||||
case <-rpcConn.Done():
|
||||
return rpcConn.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user