mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 21:09:58 +00:00
TUN-5300: Define RPC to register UDP sessions
This commit is contained in:

committed by
Arég Harutyunyan

parent
571380b3f5
commit
fc2333c934
@@ -1,16 +1,32 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// protocolSignature is a custom protocol signature to ensure that whoever performs a handshake does not write data
|
||||
// before writing the metadata.
|
||||
var protocolSignature = []byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does
|
||||
// not write data before writing the metadata.
|
||||
type ProtocolSignature [6]byte
|
||||
|
||||
var (
|
||||
// DataStreamProtocolSignature is a custom protocol signature for data stream
|
||||
DataStreamProtocolSignature = ProtocolSignature{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
|
||||
// RPCStreamProtocolSignature is a custom protocol signature for RPC stream
|
||||
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
||||
)
|
||||
|
||||
const protocolVersionLength = 2
|
||||
|
||||
@@ -20,18 +36,26 @@ const (
|
||||
protocolV1 protocolVersion = "01"
|
||||
)
|
||||
|
||||
// RequestServerStream is a stream to serve requests
|
||||
type RequestServerStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func NewRequestServerStream(stream io.ReadWriteCloser, signature ProtocolSignature) (*RequestServerStream, error) {
|
||||
if signature != DataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("RequestClientStream can only be created from data stream")
|
||||
}
|
||||
return &RequestServerStream{stream}, nil
|
||||
}
|
||||
|
||||
// ReadConnectRequestData reads the handshake data from a QUIC stream.
|
||||
func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
|
||||
if err := verifySignature(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (rss *RequestServerStream) ReadConnectRequestData() (*ConnectRequest, error) {
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(stream); err != nil {
|
||||
if _, err := readVersion(rss); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(stream).Decode()
|
||||
msg, err := capnp.NewDecoder(rss).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -43,50 +67,8 @@ func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteConnectRequestData writes requestMeta to a stream.
|
||||
func WriteConnectRequestData(stream io.Writer, dest string, connectionType ConnectionType, metadata ...Metadata) error {
|
||||
connectRequest := &ConnectRequest{
|
||||
Dest: dest,
|
||||
Type: connectionType,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
msg, err := connectRequest.toPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writePreamble(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(stream).Encode(msg)
|
||||
}
|
||||
|
||||
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
|
||||
func ReadConnectResponseData(stream io.Reader) (*ConnectResponse, error) {
|
||||
if err := verifySignature(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(stream).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &ConnectResponse{}
|
||||
if err := r.fromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteConnectResponseData writes response to a QUIC stream.
|
||||
func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metadata) error {
|
||||
func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata ...Metadata) error {
|
||||
var connectResponse *ConnectResponse
|
||||
if respErr != nil {
|
||||
connectResponse = &ConnectResponse{
|
||||
@@ -103,14 +85,105 @@ func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metad
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writePreamble(stream); err != nil {
|
||||
if err := writeDataStreamPreamble(rss); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(stream).Encode(msg)
|
||||
return capnp.NewEncoder(rss).Encode(msg)
|
||||
}
|
||||
|
||||
func writePreamble(stream io.Writer) error {
|
||||
if err := writeSignature(stream); err != nil {
|
||||
type RequestClientStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// WriteConnectRequestData writes requestMeta to a stream.
|
||||
func (rcs *RequestClientStream) WriteConnectRequestData(dest string, connectionType ConnectionType, metadata ...Metadata) error {
|
||||
connectRequest := &ConnectRequest{
|
||||
Dest: dest,
|
||||
Type: connectionType,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
msg, err := connectRequest.toPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeDataStreamPreamble(rcs); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(rcs).Encode(msg)
|
||||
}
|
||||
|
||||
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
|
||||
func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, error) {
|
||||
signature, err := DetermineProtocol(rcs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if signature != DataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(rcs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(rcs).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &ConnectResponse{}
|
||||
if err := r.fromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// RPCServerStream is a stream to serve RPCs. It is closed when the RPC client is done
|
||||
type RPCServerStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (*RPCServerStream, error) {
|
||||
if protocol != RPCStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("RPCStream can only be created from rpc stream")
|
||||
}
|
||||
return &RPCServerStream{stream}, nil
|
||||
}
|
||||
|
||||
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
|
||||
rpcTransport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(s))
|
||||
defer rpcTransport.Close()
|
||||
|
||||
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
|
||||
rpcConn := rpc.NewConn(
|
||||
rpcTransport,
|
||||
rpc.MainInterface(main.Client),
|
||||
)
|
||||
defer rpcConn.Close()
|
||||
|
||||
return rpcConn.Wait()
|
||||
}
|
||||
|
||||
func DetermineProtocol(stream io.Reader) (ProtocolSignature, error) {
|
||||
signature, err := readSignature(stream)
|
||||
if err != nil {
|
||||
return ProtocolSignature{}, err
|
||||
}
|
||||
switch signature {
|
||||
case DataStreamProtocolSignature:
|
||||
return DataStreamProtocolSignature, nil
|
||||
case RPCStreamProtocolSignature:
|
||||
return RPCStreamProtocolSignature, nil
|
||||
default:
|
||||
return ProtocolSignature{}, fmt.Errorf("Unknown signature %v", signature)
|
||||
}
|
||||
}
|
||||
|
||||
func writeDataStreamPreamble(stream io.Writer) error {
|
||||
if err := writeSignature(stream, DataStreamProtocolSignature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -128,20 +201,53 @@ func readVersion(stream io.Reader) (string, error) {
|
||||
return string(version), err
|
||||
}
|
||||
|
||||
func writeSignature(stream io.Writer) error {
|
||||
_, err := stream.Write(protocolSignature)
|
||||
func readSignature(stream io.Reader) (ProtocolSignature, error) {
|
||||
var signature ProtocolSignature
|
||||
if _, err := io.ReadFull(stream, signature[:]); err != nil {
|
||||
return ProtocolSignature{}, err
|
||||
}
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func writeSignature(stream io.Writer, signature ProtocolSignature) error {
|
||||
_, err := stream.Write(signature[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func verifySignature(stream io.Reader) error {
|
||||
signature := make([]byte, len(protocolSignature))
|
||||
if _, err := io.ReadFull(stream, signature); err != nil {
|
||||
// RPCClientStream is a stream to call methods of SessionManager
|
||||
type RPCClientStream struct {
|
||||
client tunnelpogs.SessionManager_PogsClient
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
|
||||
n, err := stream.Write(RPCStreamProtocolSignature[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != len(RPCStreamProtocolSignature) {
|
||||
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(RPCStreamProtocolSignature), n)
|
||||
}
|
||||
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
||||
conn := rpc.NewConn(
|
||||
transport,
|
||||
tunnelrpc.ConnLog(logger),
|
||||
)
|
||||
return &RPCClientStream{
|
||||
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
||||
transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(signature[0:], protocolSignature) {
|
||||
return fmt.Errorf("Wrong signature: %v", signature)
|
||||
}
|
||||
|
||||
return nil
|
||||
return resp.Err
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) Close() {
|
||||
_ = rcs.client.Close()
|
||||
_ = rcs.transport.Close()
|
||||
}
|
||||
|
@@ -2,9 +2,15 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -21,7 +27,7 @@ func TestConnectRequestData(t *testing.T) {
|
||||
hostname: "tunnel.com",
|
||||
connectionType: ConnectionTypeHTTP,
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -31,9 +37,15 @@ func TestConnectRequestData(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteConnectRequestData(b, test.hostname, test.connectionType, test.metadata...)
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
err := reqClientStream.WriteConnectRequestData(test.hostname, test.connectionType, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
reqMeta, err := ReadConnectRequestData(b)
|
||||
protocol, err := DetermineProtocol(b)
|
||||
require.NoError(t, err)
|
||||
reqServerStream, err := NewRequestServerStream(noopCloser{b}, protocol)
|
||||
require.NoError(t, err)
|
||||
|
||||
reqMeta, err := reqServerStream.ReadConnectRequestData()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.metadata, reqMeta.Metadata)
|
||||
@@ -52,7 +64,7 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
{
|
||||
name: "Signature verified and response metadata is unmarshaled and read correctly",
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -62,7 +74,7 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
name: "If error is not empty, other fields should be blank",
|
||||
err: errors.New("something happened"),
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -73,9 +85,12 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteConnectResponseData(b, test.err, test.metadata...)
|
||||
reqServerStream := RequestServerStream{noopCloser{b}}
|
||||
err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
respMeta, err := ReadConnectResponseData(b)
|
||||
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
respMeta, err := reqClientStream.ReadConnectResponseData()
|
||||
require.NoError(t, err)
|
||||
|
||||
if respMeta.Error == "" {
|
||||
@@ -86,3 +101,81 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterUdpSession(t *testing.T) {
|
||||
clientReader, serverWriter := io.Pipe()
|
||||
serverReader, clientWriter := io.Pipe()
|
||||
|
||||
clientStream := mockRPCStream{clientReader, clientWriter}
|
||||
serverStream := mockRPCStream{serverReader, serverWriter}
|
||||
|
||||
rpcServer := mockRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
}
|
||||
logger := zerolog.Nop()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
protocol, err := DetermineProtocol(serverStream)
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(rpcServer, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(sessionRegisteredChan)
|
||||
}()
|
||||
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort)
|
||||
assert.Error(t, err)
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-sessionRegisteredChan
|
||||
}
|
||||
|
||||
type mockRPCServer struct {
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
}
|
||||
|
||||
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
if !s.dstIP.Equal(dstIP) {
|
||||
return fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
|
||||
}
|
||||
if s.dstPort != dstPort {
|
||||
return fmt.Errorf("expect session ID %d, got %d", s.dstPort, dstPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRPCStream struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func (s mockRPCStream) Close() error {
|
||||
_ = s.ReadCloser.Close()
|
||||
_ = s.WriteCloser.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopCloser struct {
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
func (noopCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user