mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2117: read group/system-name from CLI, send it to edge
This commit is contained in:
@@ -12,6 +12,15 @@ import (
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// Assert *HTTPOriginConfig implements OriginConfig
|
||||
var _ OriginConfig = (*HTTPOriginConfig)(nil)
|
||||
|
||||
// Assert *WebSocketOriginConfig implements OriginConfig
|
||||
var _ OriginConfig = (*WebSocketOriginConfig)(nil)
|
||||
|
||||
// Assert *HelloWorldOriginConfig implements OriginConfig
|
||||
var _ OriginConfig = (*HelloWorldOriginConfig)(nil)
|
||||
|
||||
func TestVersion(t *testing.T) {
|
||||
firstVersion := InitVersion()
|
||||
secondVersion := Version(1)
|
||||
@@ -35,7 +44,7 @@ func TestClientConfig(t *testing.T) {
|
||||
c.Origin = sampleHTTPOriginConfig()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.Origin = sampleHTTPOriginUnixPathConfig()
|
||||
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.Origin = sampleWebSocketOriginConfig()
|
||||
@@ -136,7 +145,7 @@ func TestReverseProxyConfig(t *testing.T) {
|
||||
c.Origin = sampleHTTPOriginConfig()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.Origin = sampleHTTPOriginUnixPathConfig()
|
||||
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.Origin = sampleWebSocketOriginConfig()
|
||||
@@ -224,7 +233,17 @@ func TestOriginConfigInvalidURL(t *testing.T) {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Functions to generate sample data for ease of testing
|
||||
//
|
||||
// There's one "sample" function per struct type. Each goes like this:
|
||||
// 1. Initialize an instance of the relevant struct.
|
||||
// 2. Ensure the instance has no zero-valued fields. (This catches the
|
||||
// error-case where a field was added, but we forgot to add code to
|
||||
// marshal/unmarshal this field in CapnProto.)
|
||||
// 3. Apply one or more "override" functions (which accept a
|
||||
// pointer-to-struct, so they can mutate the instance).
|
||||
|
||||
// sampleClientConfig initializes a new ClientConfig literal,
|
||||
// applies any number of overrides to it, and returns it.
|
||||
func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
|
||||
sample := &ClientConfig{
|
||||
Version: Version(1337),
|
||||
@@ -247,6 +266,8 @@ func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
|
||||
return sample
|
||||
}
|
||||
|
||||
// sampleDoHProxyConfig initializes a new DoHProxyConfig struct,
|
||||
// applies any number of overrides to it, and returns it.
|
||||
func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
|
||||
sample := &DoHProxyConfig{
|
||||
ListenHost: "127.0.0.1",
|
||||
@@ -260,6 +281,8 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
|
||||
return sample
|
||||
}
|
||||
|
||||
// sampleReverseProxyConfig initializes a new ReverseProxyConfig struct,
|
||||
// applies any number of overrides to it, and returns it.
|
||||
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
|
||||
sample := &ReverseProxyConfig{
|
||||
TunnelHostname: "hijk.example.com",
|
||||
@@ -275,9 +298,11 @@ func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReversePr
|
||||
return sample
|
||||
}
|
||||
|
||||
// sampleHTTPOriginConfig initializes a new HTTPOriginConfig literal,
|
||||
// applies any number of overrides to it, and returns it.
|
||||
func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
|
||||
sample := &HTTPOriginConfig{
|
||||
URLString: "https.example.com",
|
||||
URLString: "https://example.com",
|
||||
TCPKeepAlive: 7 * time.Second,
|
||||
DialDualStack: true,
|
||||
TLSHandshakeTimeout: 11 * time.Second,
|
||||
@@ -297,28 +322,14 @@ func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginCon
|
||||
return sample
|
||||
}
|
||||
|
||||
func sampleHTTPOriginUnixPathConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
|
||||
sample := &HTTPOriginConfig{
|
||||
URLString: "unix:/var/lib/file.sock",
|
||||
TCPKeepAlive: 7 * time.Second,
|
||||
DialDualStack: true,
|
||||
TLSHandshakeTimeout: 11 * time.Second,
|
||||
TLSVerify: true,
|
||||
OriginCAPool: "/etc/cert.pem",
|
||||
OriginServerName: "secure.example.com",
|
||||
MaxIdleConnections: 19,
|
||||
IdleConnectionTimeout: 17 * time.Second,
|
||||
ProxyConnectionTimeout: 15 * time.Second,
|
||||
ExpectContinueTimeout: 21 * time.Second,
|
||||
ChunkedEncoding: true,
|
||||
}
|
||||
sample.ensureNoZeroFields()
|
||||
for _, f := range overrides {
|
||||
f(sample)
|
||||
}
|
||||
return sample
|
||||
// unixPathOverride sets the URLString of the given HTTPOriginConfig to be a
|
||||
// Unix socket (i.e. `unix:` scheme plus a file path)
|
||||
func unixPathOverride(sample *HTTPOriginConfig) {
|
||||
sample.URLString = "unix:/var/lib/file.sock"
|
||||
}
|
||||
|
||||
// sampleWebSocketOriginConfig initializes a new WebSocketOriginConfig
|
||||
// struct, applies any number of overrides to it, and returns it.
|
||||
func sampleWebSocketOriginConfig(overrides ...func(*WebSocketOriginConfig)) *WebSocketOriginConfig {
|
||||
sample := &WebSocketOriginConfig{
|
||||
URLString: "ssh://example.com",
|
||||
@@ -366,7 +377,6 @@ func (c *WebSocketOriginConfig) ensureNoZeroFields() {
|
||||
// include a field in the sample value, it won't be covered under tests.
|
||||
// This check reduces that risk by requiring fields to be either initialized
|
||||
// or explicitly uninitialized.
|
||||
// https://bitbucket.cfdata.org/projects/TUN/repos/cloudflared/pull-requests/151/overview?commentId=348012
|
||||
func ensureNoZeroFieldsInSample(ptrToSampleValue reflect.Value, allowedZeroFieldNames []string) {
|
||||
sampleValue := ptrToSampleValue.Elem()
|
||||
structType := ptrToSampleValue.Type().Elem()
|
||||
|
@@ -2,6 +2,7 @@ package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
@@ -127,52 +128,169 @@ func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
|
||||
return p, err
|
||||
}
|
||||
|
||||
//go-sumtype:decl Scope
|
||||
type Scope interface {
|
||||
Value() string
|
||||
isScope()
|
||||
}
|
||||
|
||||
type SystemName struct {
|
||||
systemName string
|
||||
}
|
||||
|
||||
func NewSystemName(systemName string) *SystemName {
|
||||
return &SystemName{systemName: systemName}
|
||||
}
|
||||
|
||||
func (s *SystemName) Value() string { return s.systemName }
|
||||
|
||||
func (_ *SystemName) isScope() {}
|
||||
|
||||
type Group struct {
|
||||
group string
|
||||
}
|
||||
|
||||
func NewGroup(group string) *Group {
|
||||
return &Group{group: group}
|
||||
}
|
||||
|
||||
func (g *Group) Value() string { return g.group }
|
||||
|
||||
func (_ *Group) isScope() {}
|
||||
|
||||
func MarshalScope(s tunnelrpc.Scope, p Scope) error {
|
||||
ss := s.Value()
|
||||
switch scope := p.(type) {
|
||||
case *SystemName:
|
||||
ss.SetSystemName(scope.systemName)
|
||||
case *Group:
|
||||
ss.SetGroup(scope.group)
|
||||
default:
|
||||
return fmt.Errorf("unexpected Scope value: %v", p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalScope(s tunnelrpc.Scope) (Scope, error) {
|
||||
ss := s.Value()
|
||||
switch ss.Which() {
|
||||
case tunnelrpc.Scope_value_Which_systemName:
|
||||
systemName, err := ss.SystemName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSystemName(systemName), nil
|
||||
case tunnelrpc.Scope_value_Which_group:
|
||||
group, err := ss.Group()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewGroup(group), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected Scope tag: %v", ss.Which())
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectParameters struct {
|
||||
OriginCert []byte
|
||||
CloudflaredID uuid.UUID
|
||||
NumPreviousAttempts uint8
|
||||
Tags []Tag
|
||||
CloudflaredVersion string
|
||||
}
|
||||
|
||||
// CapnpConnectParameters is ConnectParameters represented in Cap'n Proto build-in types
|
||||
type CapnpConnectParameters struct {
|
||||
OriginCert []byte
|
||||
CloudflaredID []byte
|
||||
NumPreviousAttempts uint8
|
||||
Tags []Tag
|
||||
CloudflaredVersion string
|
||||
Scope Scope
|
||||
}
|
||||
|
||||
func MarshalConnectParameters(s tunnelrpc.CapnpConnectParameters, p *ConnectParameters) error {
|
||||
if err := s.SetOriginCert(p.OriginCert); err != nil {
|
||||
return err
|
||||
}
|
||||
cloudflaredIDBytes, err := p.CloudflaredID.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
capnpConnectParameters := &CapnpConnectParameters{
|
||||
OriginCert: p.OriginCert,
|
||||
CloudflaredID: cloudflaredIDBytes,
|
||||
NumPreviousAttempts: p.NumPreviousAttempts,
|
||||
CloudflaredVersion: p.CloudflaredVersion,
|
||||
if err := s.SetCloudflaredID(cloudflaredIDBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return pogs.Insert(tunnelrpc.CapnpConnectParameters_TypeID, s.Struct, capnpConnectParameters)
|
||||
s.SetNumPreviousAttempts(p.NumPreviousAttempts)
|
||||
if len(p.Tags) > 0 {
|
||||
tagsCapnpList, err := s.NewTags(int32(len(p.Tags)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, tag := range p.Tags {
|
||||
tagCapnp := tagsCapnpList.At(i)
|
||||
if err := tagCapnp.SetName(tag.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tagCapnp.SetValue(tag.Value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := s.SetCloudflaredVersion(p.CloudflaredVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
scope, err := s.NewScope()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalScope(scope, p.Scope)
|
||||
}
|
||||
|
||||
func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectParameters, error) {
|
||||
p := new(CapnpConnectParameters)
|
||||
err := pogs.Extract(p, tunnelrpc.CapnpConnectParameters_TypeID, s.Struct)
|
||||
originCert, err := s.OriginCert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cloudflaredID, err := uuid.FromBytes(p.CloudflaredID)
|
||||
|
||||
cloudflaredIDBytes, err := s.CloudflaredID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cloudflaredID, err := uuid.FromBytes(cloudflaredIDBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tagsCapnpList, err := s.Tags()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tags []Tag
|
||||
for i := 0; i < tagsCapnpList.Len(); i++ {
|
||||
tagCapnp := tagsCapnpList.At(i)
|
||||
name, err := tagCapnp.Name()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value, err := tagCapnp.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tags = append(tags, Tag{Name: name, Value: value})
|
||||
}
|
||||
|
||||
cloudflaredVersion, err := s.CloudflaredVersion()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scopeCapnp, err := s.Scope()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err := UnmarshalScope(scopeCapnp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &ConnectParameters{
|
||||
OriginCert: p.OriginCert,
|
||||
OriginCert: originCert,
|
||||
CloudflaredID: cloudflaredID,
|
||||
NumPreviousAttempts: p.NumPreviousAttempts,
|
||||
CloudflaredVersion: p.CloudflaredVersion,
|
||||
NumPreviousAttempts: s.NumPreviousAttempts(),
|
||||
Tags: tags,
|
||||
CloudflaredVersion: cloudflaredVersion,
|
||||
Scope: scope,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
97
tunnelrpc/pogs/tunnelrpc_test.go
Normal file
97
tunnelrpc/pogs/tunnelrpc_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// Assert *SystemName implements Scope
|
||||
var _ Scope = (*SystemName)(nil)
|
||||
|
||||
// Assert *Group implements Scope
|
||||
var _ Scope = (*Group)(nil)
|
||||
|
||||
func TestScope(t *testing.T) {
|
||||
testCases := []Scope{
|
||||
&SystemName{systemName: "my_system"},
|
||||
&Group{group: "my_group"},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewScope(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalScope(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalScope(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectParameters(t *testing.T) {
|
||||
testCases := []*ConnectParameters{
|
||||
sampleConnectParameters(),
|
||||
sampleConnectParameters(func(c *ConnectParameters) {
|
||||
c.Scope = &SystemName{systemName: "my_system"}
|
||||
}),
|
||||
sampleConnectParameters(func(c *ConnectParameters) {
|
||||
c.Tags = nil
|
||||
}),
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewCapnpConnectParameters(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalConnectParameters(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalConnectParameters(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleConnectParameters(overrides ...func(*ConnectParameters)) *ConnectParameters {
|
||||
cloudflaredID, err := uuid.Parse("ED7BA470-8E54-465E-825C-99712043E01C")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sample := &ConnectParameters{
|
||||
OriginCert: []byte("my-origin-cert"),
|
||||
CloudflaredID: cloudflaredID,
|
||||
NumPreviousAttempts: 19,
|
||||
Tags: []Tag{
|
||||
Tag{
|
||||
Name: "provision-method",
|
||||
Value: "new",
|
||||
},
|
||||
},
|
||||
CloudflaredVersion: "7.0",
|
||||
Scope: &Group{group: "my_group"},
|
||||
}
|
||||
sample.ensureNoZeroFields()
|
||||
for _, f := range overrides {
|
||||
f(sample)
|
||||
}
|
||||
return sample
|
||||
}
|
||||
|
||||
func (c *ConnectParameters) ensureNoZeroFields() {
|
||||
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{})
|
||||
}
|
Reference in New Issue
Block a user