mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 11:59:58 +00:00
TUN-5300: Define RPC to register UDP sessions
This commit is contained in:

committed by
Arég Harutyunyan

parent
571380b3f5
commit
fc2333c934
@@ -2,9 +2,15 @@ package quic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -21,7 +27,7 @@ func TestConnectRequestData(t *testing.T) {
|
||||
hostname: "tunnel.com",
|
||||
connectionType: ConnectionTypeHTTP,
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -31,9 +37,15 @@ func TestConnectRequestData(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteConnectRequestData(b, test.hostname, test.connectionType, test.metadata...)
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
err := reqClientStream.WriteConnectRequestData(test.hostname, test.connectionType, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
reqMeta, err := ReadConnectRequestData(b)
|
||||
protocol, err := DetermineProtocol(b)
|
||||
require.NoError(t, err)
|
||||
reqServerStream, err := NewRequestServerStream(noopCloser{b}, protocol)
|
||||
require.NoError(t, err)
|
||||
|
||||
reqMeta, err := reqServerStream.ReadConnectRequestData()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.metadata, reqMeta.Metadata)
|
||||
@@ -52,7 +64,7 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
{
|
||||
name: "Signature verified and response metadata is unmarshaled and read correctly",
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -62,7 +74,7 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
name: "If error is not empty, other fields should be blank",
|
||||
err: errors.New("something happened"),
|
||||
metadata: []Metadata{
|
||||
Metadata{
|
||||
{
|
||||
Key: "key",
|
||||
Val: "1234",
|
||||
},
|
||||
@@ -73,9 +85,12 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
b := &bytes.Buffer{}
|
||||
err := WriteConnectResponseData(b, test.err, test.metadata...)
|
||||
reqServerStream := RequestServerStream{noopCloser{b}}
|
||||
err := reqServerStream.WriteConnectResponseData(test.err, test.metadata...)
|
||||
require.NoError(t, err)
|
||||
respMeta, err := ReadConnectResponseData(b)
|
||||
|
||||
reqClientStream := RequestClientStream{noopCloser{b}}
|
||||
respMeta, err := reqClientStream.ReadConnectResponseData()
|
||||
require.NoError(t, err)
|
||||
|
||||
if respMeta.Error == "" {
|
||||
@@ -86,3 +101,81 @@ func TestConnectResponseMeta(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterUdpSession(t *testing.T) {
|
||||
clientReader, serverWriter := io.Pipe()
|
||||
serverReader, clientWriter := io.Pipe()
|
||||
|
||||
clientStream := mockRPCStream{clientReader, clientWriter}
|
||||
serverStream := mockRPCStream{serverReader, serverWriter}
|
||||
|
||||
rpcServer := mockRPCServer{
|
||||
sessionID: uuid.New(),
|
||||
dstIP: net.IP{172, 16, 0, 1},
|
||||
dstPort: 8000,
|
||||
}
|
||||
logger := zerolog.Nop()
|
||||
sessionRegisteredChan := make(chan struct{})
|
||||
go func() {
|
||||
protocol, err := DetermineProtocol(serverStream)
|
||||
assert.NoError(t, err)
|
||||
rpcServerStream, err := NewRPCServerStream(serverStream, protocol)
|
||||
assert.NoError(t, err)
|
||||
err = rpcServerStream.Serve(rpcServer, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serverStream.Close()
|
||||
close(sessionRegisteredChan)
|
||||
}()
|
||||
|
||||
rpcClientStream, err := NewRPCClientStream(context.Background(), clientStream, &logger)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = rpcClientStream.RegisterUdpSession(context.Background(), rpcServer.sessionID, rpcServer.dstIP, rpcServer.dstPort)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Different sessionID, the RPC server should reject the registraion
|
||||
err = rpcClientStream.RegisterUdpSession(context.Background(), uuid.New(), rpcServer.dstIP, rpcServer.dstPort)
|
||||
assert.Error(t, err)
|
||||
|
||||
rpcClientStream.Close()
|
||||
<-sessionRegisteredChan
|
||||
}
|
||||
|
||||
type mockRPCServer struct {
|
||||
sessionID uuid.UUID
|
||||
dstIP net.IP
|
||||
dstPort uint16
|
||||
}
|
||||
|
||||
func (s mockRPCServer) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
|
||||
if s.sessionID != sessionID {
|
||||
return fmt.Errorf("expect session ID %s, got %s", s.sessionID, sessionID)
|
||||
}
|
||||
if !s.dstIP.Equal(dstIP) {
|
||||
return fmt.Errorf("expect destination IP %s, got %s", s.dstIP, dstIP)
|
||||
}
|
||||
if s.dstPort != dstPort {
|
||||
return fmt.Errorf("expect session ID %d, got %d", s.dstPort, dstPort)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRPCStream struct {
|
||||
io.ReadCloser
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
func (s mockRPCStream) Close() error {
|
||||
_ = s.ReadCloser.Close()
|
||||
_ = s.WriteCloser.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type noopCloser struct {
|
||||
io.ReadWriter
|
||||
}
|
||||
|
||||
func (noopCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user