mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-3085: Pass connection authentication information using TunnelAuth struct
This commit is contained in:

committed by
Adam Chalmers

parent
448a7798f7
commit
8f75feac94
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
type RegistrationServer interface {
|
||||
RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
|
||||
RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
|
||||
UnregisterConnection(ctx context.Context)
|
||||
}
|
||||
|
||||
@@ -32,6 +32,11 @@ type ConnectionOptions struct {
|
||||
CompressionQuality uint8
|
||||
}
|
||||
|
||||
type TunnelAuth struct {
|
||||
AccountTag string
|
||||
TunnelSecret []byte
|
||||
}
|
||||
|
||||
func (p *ConnectionOptions) MarshalCapnproto(s tunnelrpc.ConnectionOptions) error {
|
||||
return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, s.Struct, p)
|
||||
}
|
||||
@@ -40,6 +45,14 @@ func (p *ConnectionOptions) UnmarshalCapnproto(s tunnelrpc.ConnectionOptions) er
|
||||
return pogs.Extract(p, tunnelrpc.ConnectionOptions_TypeID, s.Struct)
|
||||
}
|
||||
|
||||
func (a *TunnelAuth) MarshalCapnproto(s tunnelrpc.TunnelAuth) error {
|
||||
return pogs.Insert(tunnelrpc.TunnelAuth_TypeID, s.Struct, a)
|
||||
}
|
||||
|
||||
func (a *TunnelAuth) UnmarshalCapnproto(s tunnelrpc.TunnelAuth) error {
|
||||
return pogs.Extract(a, tunnelrpc.TunnelAuth_TypeID, s.Struct)
|
||||
}
|
||||
|
||||
type ConnectionDetails struct {
|
||||
UUID uuid.UUID
|
||||
Location string
|
||||
@@ -92,6 +105,11 @@ func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsAuth TunnelAuth
|
||||
err = pogsAuth.UnmarshalCapnproto(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uuidBytes, err := p.Params.TunnelId()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -111,7 +129,7 @@ func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer
|
||||
return err
|
||||
}
|
||||
|
||||
connDetails, callError := i.impl.RegisterConnection(p.Ctx, auth, tunnelID, connIndex, &pogsOptions)
|
||||
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
|
||||
|
||||
resp, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
@@ -140,10 +158,17 @@ func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServ
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth []byte, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||
err := p.SetAuth(auth)
|
||||
tunnelAuth, err := p.NewAuth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetAuth(tunnelAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -16,6 +16,8 @@ import (
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
)
|
||||
|
||||
const testAccountTag = "abc123"
|
||||
|
||||
func TestMarshalConnectionOptions(t *testing.T) {
|
||||
clientID := uuid.New()
|
||||
orig := ConnectionOptions{
|
||||
@@ -47,6 +49,7 @@ func TestMarshalConnectionOptions(t *testing.T) {
|
||||
|
||||
func TestConnectionRegistrationRPC(t *testing.T) {
|
||||
p1, p2 := net.Pipe()
|
||||
|
||||
t1, t2 := rpc.StreamTransport(p1), rpc.StreamTransport(p2)
|
||||
|
||||
// Server-side
|
||||
@@ -84,9 +87,14 @@ func TestConnectionRegistrationRPC(t *testing.T) {
|
||||
testImpl.details = &expectedDetails
|
||||
testImpl.err = nil
|
||||
|
||||
auth := TunnelAuth{
|
||||
AccountTag: testAccountTag,
|
||||
TunnelSecret: []byte{1, 2, 3, 4},
|
||||
}
|
||||
|
||||
// success
|
||||
tunnelID := uuid.New()
|
||||
details, err := client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
details, err := client.RegisterConnection(ctx, auth, tunnelID, 2, options)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedDetails, *details)
|
||||
|
||||
@@ -94,15 +102,15 @@ func TestConnectionRegistrationRPC(t *testing.T) {
|
||||
testImpl.details = nil
|
||||
testImpl.err = errors.New("internal")
|
||||
|
||||
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
_, err = client.RegisterConnection(ctx, auth, tunnelID, 2, options)
|
||||
assert.EqualError(t, err, "internal")
|
||||
|
||||
// retriable error
|
||||
testImpl.details = nil
|
||||
const delay = 27*time.Second
|
||||
const delay = 27 * time.Second
|
||||
testImpl.err = RetryErrorAfter(errors.New("retryable"), delay)
|
||||
|
||||
_, err = client.Register(ctx, []byte{1, 2, 3}, tunnelID, 2, options)
|
||||
_, err = client.RegisterConnection(ctx, auth, tunnelID, 2, options)
|
||||
assert.EqualError(t, err, "retryable")
|
||||
|
||||
re, ok := err.(*RetryableError)
|
||||
@@ -117,7 +125,10 @@ type testConnectionRegistrationServer struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (t testConnectionRegistrationServer) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
func (t *testConnectionRegistrationServer) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
if auth.AccountTag != testAccountTag {
|
||||
panic("bad account tag: " + auth.AccountTag)
|
||||
}
|
||||
if t.err != nil {
|
||||
return nil, t.err
|
||||
}
|
||||
|
@@ -11,12 +11,12 @@ import (
|
||||
// mocks for specific unit tests without having to implement every method
|
||||
type mockTunnelServerBase struct{}
|
||||
|
||||
func (mockTunnelServerBase) Register(ctx context.Context, auth []byte, tunnelUUID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
panic("unexpected call to Register")
|
||||
func (mockTunnelServerBase) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
panic("unexpected call to RegisterConnection")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) Unregister(ctx context.Context) {
|
||||
panic("unexpected call to Unregister")
|
||||
func (mockTunnelServerBase) UnregisterConnection(ctx context.Context) {
|
||||
panic("unexpected call to UnregisterConnection")
|
||||
}
|
||||
|
||||
func (mockTunnelServerBase) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
|
||||
|
Reference in New Issue
Block a user