TUN-5300: Define RPC to register UDP sessions

This commit is contained in:
cthuang
2021-11-12 09:37:28 +00:00
committed by Arég Harutyunyan
parent 571380b3f5
commit fc2333c934
9 changed files with 1035 additions and 301 deletions

View File

@@ -1,16 +1,32 @@
package quic
import (
"bytes"
"context"
"fmt"
"io"
"net"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// protocolSignature is a custom protocol signature to ensure that whoever performs a handshake does not write data
// before writing the metadata.
var protocolSignature = []byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does
// not write data before writing the metadata.
type ProtocolSignature [6]byte
var (
// DataStreamProtocolSignature is a custom protocol signature for data stream
DataStreamProtocolSignature = ProtocolSignature{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
// RPCStreamProtocolSignature is a custom protocol signature for RPC stream
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
)
const protocolVersionLength = 2
@@ -20,18 +36,26 @@ const (
protocolV1 protocolVersion = "01"
)
// RequestServerStream is a stream to serve requests
type RequestServerStream struct {
io.ReadWriteCloser
}
func NewRequestServerStream(stream io.ReadWriteCloser, signature ProtocolSignature) (*RequestServerStream, error) {
if signature != DataStreamProtocolSignature {
return nil, fmt.Errorf("RequestClientStream can only be created from data stream")
}
return &RequestServerStream{stream}, nil
}
// ReadConnectRequestData reads the handshake data from a QUIC stream.
func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
if err := verifySignature(stream); err != nil {
return nil, err
}
func (rss *RequestServerStream) ReadConnectRequestData() (*ConnectRequest, error) {
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
if _, err := readVersion(stream); err != nil {
if _, err := readVersion(rss); err != nil {
return nil, err
}
msg, err := capnp.NewDecoder(stream).Decode()
msg, err := capnp.NewDecoder(rss).Decode()
if err != nil {
return nil, err
}
@@ -43,50 +67,8 @@ func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
return r, nil
}
// WriteConnectRequestData writes requestMeta to a stream.
func WriteConnectRequestData(stream io.Writer, dest string, connectionType ConnectionType, metadata ...Metadata) error {
connectRequest := &ConnectRequest{
Dest: dest,
Type: connectionType,
Metadata: metadata,
}
msg, err := connectRequest.toPogs()
if err != nil {
return err
}
if err := writePreamble(stream); err != nil {
return err
}
return capnp.NewEncoder(stream).Encode(msg)
}
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
func ReadConnectResponseData(stream io.Reader) (*ConnectResponse, error) {
if err := verifySignature(stream); err != nil {
return nil, err
}
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
if _, err := readVersion(stream); err != nil {
return nil, err
}
msg, err := capnp.NewDecoder(stream).Decode()
if err != nil {
return nil, err
}
r := &ConnectResponse{}
if err := r.fromPogs(msg); err != nil {
return nil, err
}
return r, nil
}
// WriteConnectResponseData writes response to a QUIC stream.
func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metadata) error {
func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata ...Metadata) error {
var connectResponse *ConnectResponse
if respErr != nil {
connectResponse = &ConnectResponse{
@@ -103,14 +85,105 @@ func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metad
return err
}
if err := writePreamble(stream); err != nil {
if err := writeDataStreamPreamble(rss); err != nil {
return err
}
return capnp.NewEncoder(stream).Encode(msg)
return capnp.NewEncoder(rss).Encode(msg)
}
func writePreamble(stream io.Writer) error {
if err := writeSignature(stream); err != nil {
type RequestClientStream struct {
io.ReadWriteCloser
}
// WriteConnectRequestData writes requestMeta to a stream.
func (rcs *RequestClientStream) WriteConnectRequestData(dest string, connectionType ConnectionType, metadata ...Metadata) error {
connectRequest := &ConnectRequest{
Dest: dest,
Type: connectionType,
Metadata: metadata,
}
msg, err := connectRequest.toPogs()
if err != nil {
return err
}
if err := writeDataStreamPreamble(rcs); err != nil {
return err
}
return capnp.NewEncoder(rcs).Encode(msg)
}
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, error) {
signature, err := DetermineProtocol(rcs)
if err != nil {
return nil, err
}
if signature != DataStreamProtocolSignature {
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
}
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
if _, err := readVersion(rcs); err != nil {
return nil, err
}
msg, err := capnp.NewDecoder(rcs).Decode()
if err != nil {
return nil, err
}
r := &ConnectResponse{}
if err := r.fromPogs(msg); err != nil {
return nil, err
}
return r, nil
}
// RPCServerStream is a stream to serve RPCs. It is closed when the RPC client is done
type RPCServerStream struct {
io.ReadWriteCloser
}
func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (*RPCServerStream, error) {
if protocol != RPCStreamProtocolSignature {
return nil, fmt.Errorf("RPCStream can only be created from rpc stream")
}
return &RPCServerStream{stream}, nil
}
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
rpcTransport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(s))
defer rpcTransport.Close()
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
rpcConn := rpc.NewConn(
rpcTransport,
rpc.MainInterface(main.Client),
)
defer rpcConn.Close()
return rpcConn.Wait()
}
func DetermineProtocol(stream io.Reader) (ProtocolSignature, error) {
signature, err := readSignature(stream)
if err != nil {
return ProtocolSignature{}, err
}
switch signature {
case DataStreamProtocolSignature:
return DataStreamProtocolSignature, nil
case RPCStreamProtocolSignature:
return RPCStreamProtocolSignature, nil
default:
return ProtocolSignature{}, fmt.Errorf("Unknown signature %v", signature)
}
}
func writeDataStreamPreamble(stream io.Writer) error {
if err := writeSignature(stream, DataStreamProtocolSignature); err != nil {
return err
}
@@ -128,20 +201,53 @@ func readVersion(stream io.Reader) (string, error) {
return string(version), err
}
func writeSignature(stream io.Writer) error {
_, err := stream.Write(protocolSignature)
func readSignature(stream io.Reader) (ProtocolSignature, error) {
var signature ProtocolSignature
if _, err := io.ReadFull(stream, signature[:]); err != nil {
return ProtocolSignature{}, err
}
return signature, nil
}
func writeSignature(stream io.Writer, signature ProtocolSignature) error {
_, err := stream.Write(signature[:])
return err
}
func verifySignature(stream io.Reader) error {
signature := make([]byte, len(protocolSignature))
if _, err := io.ReadFull(stream, signature); err != nil {
// RPCClientStream is a stream to call methods of SessionManager
type RPCClientStream struct {
client tunnelpogs.SessionManager_PogsClient
transport rpc.Transport
}
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
n, err := stream.Write(RPCStreamProtocolSignature[:])
if err != nil {
return nil, err
}
if n != len(RPCStreamProtocolSignature) {
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(RPCStreamProtocolSignature), n)
}
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
conn := rpc.NewConn(
transport,
tunnelrpc.ConnLog(logger),
)
return &RPCClientStream{
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
transport: transport,
}, nil
}
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort)
if err != nil {
return err
}
if !bytes.Equal(signature[0:], protocolSignature) {
return fmt.Errorf("Wrong signature: %v", signature)
}
return nil
return resp.Err
}
func (rcs *RPCClientStream) Close() {
_ = rcs.client.Close()
_ = rcs.transport.Close()
}

View File

@@ -2,9 +2,15 @@ package quic
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -21,7 +27,7 @@ func TestConnectRequestData(t *testing.T) {
hostname: "tunnel.com",
connectionType: ConnectionTypeHTTP,
metadata: []Metadata{
Metadata{
{
Key: "key",
Val: "1234",
},
@@ -31,9 +37,15 @@ func TestConnectRequestData(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b := &bytes.Buffer{}
err := WriteConnectRequestData(b, test.hostname, test.connectionType, test.metadata...)
reqClientStream := RequestClientStream{noopCloser{b}}
err := reqClientStream.WriteConnectRequestData(test.hostname, test.connectionType, test.metadata...)
require.NoError(t, err)
reqMeta, err := ReadConnectRequestData(b)
protocol, err := DetermineProtocol(b)
require.NoError(t, err)
reqServerStream, err := NewRequestServerStream(noopCloser{b}, protocol)
require.NoError(t, err)
reqMeta, err := reqServerStream.ReadConnectRequestData()
require.NoError(t, err)
assert.Equal(t, test.metadata, reqMeta.Metadata)
@@ -52,7 +64,7 @@ func TestConnectResponseMeta(t *testing.T) {
{
name: "Signature verified and response metadata is unmarshaled and read correctly",
metadata: []Metadata{
Metadata{
{
Key: "key",
Val: "1234",
},
@@ -62,7 +74,7 @@ func TestConnectResponseMeta(t *testing.T) {
name: "If error is not empty, other fields should be blank",
err: errors.New("something happened"),
metadata: []Metadata{
Metadata{
{
Key: "key",
Val: "1234",
},
@@ -73,9 +85,12 @@ func TestConnectResponseMeta(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
b := &bytes.Buffer{}
err := WriteConnectResponseData(b, test.err, test.metadata...)
reqServerStream := RequestServerStream{noopCloser{b}}
err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
require.NoError(t, err)
respMeta, err := ReadConnectResponseData(b)
reqClientStream := RequestClientStream{noopCloser{b}}
respMeta, err := reqClientStream.ReadConnectResponseData()
require.NoError(t, err)
if respMeta.Error == "" {
@@ -86,3 +101,81 @@ func TestConnectResponseMeta(t *testing.T) {
})
}
}
func TestRegisterUdpSession(t *testing.T) {
clientReader, serverWriter := io.Pipe()
serverReader, clientWriter := io.Pipe()
clientStream := mockRPCStream{clientReader, clientWriter}
serverStream := mockRPCStream{serverReader, serverWriter}
rpcServer := mockRPCServer{
sessionID: uuid.New(),
dstIP: net.IP{172, 16, 0, 1},
dstPort: 8000,
}
logger := zerolog.Nop()
sessionRegisteredChan := make(chan struct{})
go func() {
protocol, err := DetermineProtocol(serverStream)
assert.NoError(t, err)
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
assert.NoError(t, err)
err = rpcServerStream.Serve(rpcServer, &logger)
assert.NoError(t, err)
serverStream.Close()
close(sessionRegisteredChan)
}()
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
assert.NoError(t, err)
err = rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort)
assert.NoError(t, err)
// Different sessionID, the RPC server should reject the registraion
err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort)
assert.Error(t, err)
rpcClientStream.Close()
<-sessionRegisteredChan
}
type mockRPCServer struct {
sessionID uuid.UUID
dstIP net.IP
dstPort uint16
}
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
if s.sessionID != sessionID {
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
}
if !s.dstIP.Equal(dstIP) {
return fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
}
if s.dstPort != dstPort {
return fmt.Errorf("expect session ID %d, got %d", s.dstPort, dstPort)
}
return nil
}
type mockRPCStream struct {
io.ReadCloser
io.WriteCloser
}
func (s mockRPCStream) Close() error {
_ = s.ReadCloser.Close()
_ = s.WriteCloser.Close()
return nil
}
type noopCloser struct {
io.ReadWriter
}
func (noopCloser) Close() error {
return nil
}