mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 12:09:57 +00:00
TUN-5300: Define RPC to register UDP sessions
This commit is contained in:

committed by
Arég Harutyunyan

parent
571380b3f5
commit
fc2333c934
@@ -1,16 +1,32 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// protocolSignature is a custom protocol signature to ensure that whoever performs a handshake does not write data
|
||||
// before writing the metadata.
|
||||
var protocolSignature = []byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
// The first 6 bytes of the stream is used to distinguish the type of stream. It ensures whoever performs a handshake does
|
||||
// not write data before writing the metadata.
|
||||
type ProtocolSignature [6]byte
|
||||
|
||||
var (
|
||||
// DataStreamProtocolSignature is a custom protocol signature for data stream
|
||||
DataStreamProtocolSignature = ProtocolSignature{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
|
||||
|
||||
// RPCStreamProtocolSignature is a custom protocol signature for RPC stream
|
||||
RPCStreamProtocolSignature = ProtocolSignature{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
|
||||
)
|
||||
|
||||
const protocolVersionLength = 2
|
||||
|
||||
@@ -20,18 +36,26 @@ const (
|
||||
protocolV1 protocolVersion = "01"
|
||||
)
|
||||
|
||||
// RequestServerStream is a stream to serve requests
|
||||
type RequestServerStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func NewRequestServerStream(stream io.ReadWriteCloser, signature ProtocolSignature) (*RequestServerStream, error) {
|
||||
if signature != DataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("RequestClientStream can only be created from data stream")
|
||||
}
|
||||
return &RequestServerStream{stream}, nil
|
||||
}
|
||||
|
||||
// ReadConnectRequestData reads the handshake data from a QUIC stream.
|
||||
func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
|
||||
if err := verifySignature(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (rss *RequestServerStream) ReadConnectRequestData() (*ConnectRequest, error) {
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(stream); err != nil {
|
||||
if _, err := readVersion(rss); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(stream).Decode()
|
||||
msg, err := capnp.NewDecoder(rss).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -43,50 +67,8 @@ func ReadConnectRequestData(stream io.Reader) (*ConnectRequest, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteConnectRequestData writes requestMeta to a stream.
|
||||
func WriteConnectRequestData(stream io.Writer, dest string, connectionType ConnectionType, metadata ...Metadata) error {
|
||||
connectRequest := &ConnectRequest{
|
||||
Dest: dest,
|
||||
Type: connectionType,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
msg, err := connectRequest.toPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writePreamble(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(stream).Encode(msg)
|
||||
}
|
||||
|
||||
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
|
||||
func ReadConnectResponseData(stream io.Reader) (*ConnectResponse, error) {
|
||||
if err := verifySignature(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(stream); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(stream).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &ConnectResponse{}
|
||||
if err := r.fromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// WriteConnectResponseData writes response to a QUIC stream.
|
||||
func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metadata) error {
|
||||
func (rss *RequestServerStream) WriteConnectResponseData(respErr error, metadata ...Metadata) error {
|
||||
var connectResponse *ConnectResponse
|
||||
if respErr != nil {
|
||||
connectResponse = &ConnectResponse{
|
||||
@@ -103,14 +85,105 @@ func WriteConnectResponseData(stream io.Writer, respErr error, metadata ...Metad
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writePreamble(stream); err != nil {
|
||||
if err := writeDataStreamPreamble(rss); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(stream).Encode(msg)
|
||||
return capnp.NewEncoder(rss).Encode(msg)
|
||||
}
|
||||
|
||||
func writePreamble(stream io.Writer) error {
|
||||
if err := writeSignature(stream); err != nil {
|
||||
type RequestClientStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
// WriteConnectRequestData writes requestMeta to a stream.
|
||||
func (rcs *RequestClientStream) WriteConnectRequestData(dest string, connectionType ConnectionType, metadata ...Metadata) error {
|
||||
connectRequest := &ConnectRequest{
|
||||
Dest: dest,
|
||||
Type: connectionType,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
msg, err := connectRequest.toPogs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeDataStreamPreamble(rcs); err != nil {
|
||||
return err
|
||||
}
|
||||
return capnp.NewEncoder(rcs).Encode(msg)
|
||||
}
|
||||
|
||||
// ReadConnectResponseData reads the response to a RequestMeta in a stream.
|
||||
func (rcs *RequestClientStream) ReadConnectResponseData() (*ConnectResponse, error) {
|
||||
signature, err := DetermineProtocol(rcs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if signature != DataStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("Wrong protocol signature %v", signature)
|
||||
}
|
||||
|
||||
// This is a NO-OP for now. We could cause a branching if we wanted to use multiple versions.
|
||||
if _, err := readVersion(rcs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := capnp.NewDecoder(rcs).Decode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := &ConnectResponse{}
|
||||
if err := r.fromPogs(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// RPCServerStream is a stream to serve RPCs. It is closed when the RPC client is done
|
||||
type RPCServerStream struct {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func NewRPCServerStream(stream io.ReadWriteCloser, protocol ProtocolSignature) (*RPCServerStream, error) {
|
||||
if protocol != RPCStreamProtocolSignature {
|
||||
return nil, fmt.Errorf("RPCStream can only be created from rpc stream")
|
||||
}
|
||||
return &RPCServerStream{stream}, nil
|
||||
}
|
||||
|
||||
func (s *RPCServerStream) Serve(sessionManager tunnelpogs.SessionManager, logger *zerolog.Logger) error {
|
||||
rpcTransport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(s))
|
||||
defer rpcTransport.Close()
|
||||
|
||||
main := tunnelpogs.SessionManager_ServerToClient(sessionManager)
|
||||
rpcConn := rpc.NewConn(
|
||||
rpcTransport,
|
||||
rpc.MainInterface(main.Client),
|
||||
)
|
||||
defer rpcConn.Close()
|
||||
|
||||
return rpcConn.Wait()
|
||||
}
|
||||
|
||||
func DetermineProtocol(stream io.Reader) (ProtocolSignature, error) {
|
||||
signature, err := readSignature(stream)
|
||||
if err != nil {
|
||||
return ProtocolSignature{}, err
|
||||
}
|
||||
switch signature {
|
||||
case DataStreamProtocolSignature:
|
||||
return DataStreamProtocolSignature, nil
|
||||
case RPCStreamProtocolSignature:
|
||||
return RPCStreamProtocolSignature, nil
|
||||
default:
|
||||
return ProtocolSignature{}, fmt.Errorf("Unknown signature %v", signature)
|
||||
}
|
||||
}
|
||||
|
||||
func writeDataStreamPreamble(stream io.Writer) error {
|
||||
if err := writeSignature(stream, DataStreamProtocolSignature); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -128,20 +201,53 @@ func readVersion(stream io.Reader) (string, error) {
|
||||
return string(version), err
|
||||
}
|
||||
|
||||
func writeSignature(stream io.Writer) error {
|
||||
_, err := stream.Write(protocolSignature)
|
||||
func readSignature(stream io.Reader) (ProtocolSignature, error) {
|
||||
var signature ProtocolSignature
|
||||
if _, err := io.ReadFull(stream, signature[:]); err != nil {
|
||||
return ProtocolSignature{}, err
|
||||
}
|
||||
return signature, nil
|
||||
}
|
||||
|
||||
func writeSignature(stream io.Writer, signature ProtocolSignature) error {
|
||||
_, err := stream.Write(signature[:])
|
||||
return err
|
||||
}
|
||||
|
||||
func verifySignature(stream io.Reader) error {
|
||||
signature := make([]byte, len(protocolSignature))
|
||||
if _, err := io.ReadFull(stream, signature); err != nil {
|
||||
// RPCClientStream is a stream to call methods of SessionManager
|
||||
type RPCClientStream struct {
|
||||
client tunnelpogs.SessionManager_PogsClient
|
||||
transport rpc.Transport
|
||||
}
|
||||
|
||||
func NewRPCClientStream(ctx context.Context, stream io.ReadWriteCloser, logger *zerolog.Logger) (*RPCClientStream, error) {
|
||||
n, err := stream.Write(RPCStreamProtocolSignature[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if n != len(RPCStreamProtocolSignature) {
|
||||
return nil, fmt.Errorf("expect to write %d bytes for RPC stream protocol signature, wrote %d", len(RPCStreamProtocolSignature), n)
|
||||
}
|
||||
transport := tunnelrpc.NewTransportLogger(logger, rpc.StreamTransport(stream))
|
||||
conn := rpc.NewConn(
|
||||
transport,
|
||||
tunnelrpc.ConnLog(logger),
|
||||
)
|
||||
return &RPCClientStream{
|
||||
client: tunnelpogs.SessionManager_PogsClient{Client: conn.Bootstrap(ctx), Conn: conn},
|
||||
transport: transport,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
resp, err := rcs.client.RegisterUdpSession(ctx, sessionID, dstIP, dstPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !bytes.Equal(signature[0:], protocolSignature) {
|
||||
return fmt.Errorf("Wrong signature: %v", signature)
|
||||
}
|
||||
|
||||
return nil
|
||||
return resp.Err
|
||||
}
|
||||
|
||||
func (rcs *RPCClientStream) Close() {
|
||||
_ = rcs.client.Close()
|
||||
_ = rcs.transport.Close()
|
||||
}
|
||||
|
Reference in New Issue
Block a user