TUN-2117: read group/system-name from CLI, send it to edge

This commit is contained in:
Nick Vollmar
2019-07-30 13:55:34 -05:00
parent 3c93d9b300
commit 74f3a55c57
9 changed files with 714 additions and 273 deletions

View File

@@ -12,6 +12,15 @@ import (
capnp "zombiezen.com/go/capnproto2"
)
// Assert *HTTPOriginConfig implements OriginConfig
var _ OriginConfig = (*HTTPOriginConfig)(nil)
// Assert *WebSocketOriginConfig implements OriginConfig
var _ OriginConfig = (*WebSocketOriginConfig)(nil)
// Assert *HelloWorldOriginConfig implements OriginConfig
var _ OriginConfig = (*HelloWorldOriginConfig)(nil)
func TestVersion(t *testing.T) {
firstVersion := InitVersion()
secondVersion := Version(1)
@@ -35,7 +44,7 @@ func TestClientConfig(t *testing.T) {
c.Origin = sampleHTTPOriginConfig()
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginUnixPathConfig()
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
@@ -136,7 +145,7 @@ func TestReverseProxyConfig(t *testing.T) {
c.Origin = sampleHTTPOriginConfig()
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginUnixPathConfig()
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
@@ -224,7 +233,17 @@ func TestOriginConfigInvalidURL(t *testing.T) {
//////////////////////////////////////////////////////////////////////////////
// Functions to generate sample data for ease of testing
//
// There's one "sample" function per struct type. Each goes like this:
// 1. Initialize an instance of the relevant struct.
// 2. Ensure the instance has no zero-valued fields. (This catches the
// error-case where a field was added, but we forgot to add code to
// marshal/unmarshal this field in CapnProto.)
// 3. Apply one or more "override" functions (which accept a
// pointer-to-struct, so they can mutate the instance).
// sampleClientConfig initializes a new ClientConfig literal,
// applies any number of overrides to it, and returns it.
func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
sample := &ClientConfig{
Version: Version(1337),
@@ -247,6 +266,8 @@ func sampleClientConfig(overrides ...func(*ClientConfig)) *ClientConfig {
return sample
}
// sampleDoHProxyConfig initializes a new DoHProxyConfig struct,
// applies any number of overrides to it, and returns it.
func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
sample := &DoHProxyConfig{
ListenHost: "127.0.0.1",
@@ -260,6 +281,8 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
return sample
}
// sampleReverseProxyConfig initializes a new ReverseProxyConfig struct,
// applies any number of overrides to it, and returns it.
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
sample := &ReverseProxyConfig{
TunnelHostname: "hijk.example.com",
@@ -275,9 +298,11 @@ func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReversePr
return sample
}
// sampleHTTPOriginConfig initializes a new HTTPOriginConfig literal,
// applies any number of overrides to it, and returns it.
func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "https.example.com",
URLString: "https://example.com",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
@@ -297,28 +322,14 @@ func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginCon
return sample
}
func sampleHTTPOriginUnixPathConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "unix:/var/lib/file.sock",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
TLSVerify: true,
OriginCAPool: "/etc/cert.pem",
OriginServerName: "secure.example.com",
MaxIdleConnections: 19,
IdleConnectionTimeout: 17 * time.Second,
ProxyConnectionTimeout: 15 * time.Second,
ExpectContinueTimeout: 21 * time.Second,
ChunkedEncoding: true,
}
sample.ensureNoZeroFields()
for _, f := range overrides {
f(sample)
}
return sample
// unixPathOverride sets the URLString of the given HTTPOriginConfig to be a
// Unix socket (i.e. `unix:` scheme plus a file path)
func unixPathOverride(sample *HTTPOriginConfig) {
sample.URLString = "unix:/var/lib/file.sock"
}
// sampleWebSocketOriginConfig initializes a new WebSocketOriginConfig
// struct, applies any number of overrides to it, and returns it.
func sampleWebSocketOriginConfig(overrides ...func(*WebSocketOriginConfig)) *WebSocketOriginConfig {
sample := &WebSocketOriginConfig{
URLString: "ssh://example.com",
@@ -366,7 +377,6 @@ func (c *WebSocketOriginConfig) ensureNoZeroFields() {
// include a field in the sample value, it won't be covered under tests.
// This check reduces that risk by requiring fields to be either initialized
// or explicitly uninitialized.
// https://bitbucket.cfdata.org/projects/TUN/repos/cloudflared/pull-requests/151/overview?commentId=348012
func ensureNoZeroFieldsInSample(ptrToSampleValue reflect.Value, allowedZeroFieldNames []string) {
sampleValue := ptrToSampleValue.Elem()
structType := ptrToSampleValue.Type().Elem()

View File

@@ -2,6 +2,7 @@ package pogs
import (
"context"
"fmt"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc"
@@ -127,52 +128,169 @@ func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
return p, err
}
//go-sumtype:decl Scope
type Scope interface {
Value() string
isScope()
}
type SystemName struct {
systemName string
}
func NewSystemName(systemName string) *SystemName {
return &SystemName{systemName: systemName}
}
func (s *SystemName) Value() string { return s.systemName }
func (_ *SystemName) isScope() {}
type Group struct {
group string
}
func NewGroup(group string) *Group {
return &Group{group: group}
}
func (g *Group) Value() string { return g.group }
func (_ *Group) isScope() {}
func MarshalScope(s tunnelrpc.Scope, p Scope) error {
ss := s.Value()
switch scope := p.(type) {
case *SystemName:
ss.SetSystemName(scope.systemName)
case *Group:
ss.SetGroup(scope.group)
default:
return fmt.Errorf("unexpected Scope value: %v", p)
}
return nil
}
func UnmarshalScope(s tunnelrpc.Scope) (Scope, error) {
ss := s.Value()
switch ss.Which() {
case tunnelrpc.Scope_value_Which_systemName:
systemName, err := ss.SystemName()
if err != nil {
return nil, err
}
return NewSystemName(systemName), nil
case tunnelrpc.Scope_value_Which_group:
group, err := ss.Group()
if err != nil {
return nil, err
}
return NewGroup(group), nil
default:
return nil, fmt.Errorf("unexpected Scope tag: %v", ss.Which())
}
}
type ConnectParameters struct {
OriginCert []byte
CloudflaredID uuid.UUID
NumPreviousAttempts uint8
Tags []Tag
CloudflaredVersion string
}
// CapnpConnectParameters is ConnectParameters represented in Cap'n Proto build-in types
type CapnpConnectParameters struct {
OriginCert []byte
CloudflaredID []byte
NumPreviousAttempts uint8
Tags []Tag
CloudflaredVersion string
Scope Scope
}
func MarshalConnectParameters(s tunnelrpc.CapnpConnectParameters, p *ConnectParameters) error {
if err := s.SetOriginCert(p.OriginCert); err != nil {
return err
}
cloudflaredIDBytes, err := p.CloudflaredID.MarshalBinary()
if err != nil {
return err
}
capnpConnectParameters := &CapnpConnectParameters{
OriginCert: p.OriginCert,
CloudflaredID: cloudflaredIDBytes,
NumPreviousAttempts: p.NumPreviousAttempts,
CloudflaredVersion: p.CloudflaredVersion,
if err := s.SetCloudflaredID(cloudflaredIDBytes); err != nil {
return err
}
return pogs.Insert(tunnelrpc.CapnpConnectParameters_TypeID, s.Struct, capnpConnectParameters)
s.SetNumPreviousAttempts(p.NumPreviousAttempts)
if len(p.Tags) > 0 {
tagsCapnpList, err := s.NewTags(int32(len(p.Tags)))
if err != nil {
return err
}
for i, tag := range p.Tags {
tagCapnp := tagsCapnpList.At(i)
if err := tagCapnp.SetName(tag.Name); err != nil {
return err
}
if err := tagCapnp.SetValue(tag.Value); err != nil {
return err
}
}
}
if err := s.SetCloudflaredVersion(p.CloudflaredVersion); err != nil {
return err
}
scope, err := s.NewScope()
if err != nil {
return err
}
return MarshalScope(scope, p.Scope)
}
func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectParameters, error) {
p := new(CapnpConnectParameters)
err := pogs.Extract(p, tunnelrpc.CapnpConnectParameters_TypeID, s.Struct)
originCert, err := s.OriginCert()
if err != nil {
return nil, err
}
cloudflaredID, err := uuid.FromBytes(p.CloudflaredID)
cloudflaredIDBytes, err := s.CloudflaredID()
if err != nil {
return nil, err
}
cloudflaredID, err := uuid.FromBytes(cloudflaredIDBytes)
if err != nil {
return nil, err
}
tagsCapnpList, err := s.Tags()
if err != nil {
return nil, err
}
var tags []Tag
for i := 0; i < tagsCapnpList.Len(); i++ {
tagCapnp := tagsCapnpList.At(i)
name, err := tagCapnp.Name()
if err != nil {
return nil, err
}
value, err := tagCapnp.Value()
if err != nil {
return nil, err
}
tags = append(tags, Tag{Name: name, Value: value})
}
cloudflaredVersion, err := s.CloudflaredVersion()
if err != nil {
return nil, err
}
scopeCapnp, err := s.Scope()
if err != nil {
return nil, err
}
scope, err := UnmarshalScope(scopeCapnp)
if err != nil {
return nil, err
}
return &ConnectParameters{
OriginCert: p.OriginCert,
OriginCert: originCert,
CloudflaredID: cloudflaredID,
NumPreviousAttempts: p.NumPreviousAttempts,
CloudflaredVersion: p.CloudflaredVersion,
NumPreviousAttempts: s.NumPreviousAttempts(),
Tags: tags,
CloudflaredVersion: cloudflaredVersion,
Scope: scope,
}, nil
}

View File

@@ -0,0 +1,97 @@
package pogs
import (
"reflect"
"testing"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
)
// Assert *SystemName implements Scope
var _ Scope = (*SystemName)(nil)
// Assert *Group implements Scope
var _ Scope = (*Group)(nil)
func TestScope(t *testing.T) {
testCases := []Scope{
&SystemName{systemName: "my_system"},
&Group{group: "my_group"},
}
for i, testCase := range testCases {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
capnpEntity, err := tunnelrpc.NewScope(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalScope(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
continue
}
result, err := UnmarshalScope(capnpEntity)
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}
func TestConnectParameters(t *testing.T) {
testCases := []*ConnectParameters{
sampleConnectParameters(),
sampleConnectParameters(func(c *ConnectParameters) {
c.Scope = &SystemName{systemName: "my_system"}
}),
sampleConnectParameters(func(c *ConnectParameters) {
c.Tags = nil
}),
}
for i, testCase := range testCases {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
capnpEntity, err := tunnelrpc.NewCapnpConnectParameters(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalConnectParameters(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
continue
}
result, err := UnmarshalConnectParameters(capnpEntity)
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}
func sampleConnectParameters(overrides ...func(*ConnectParameters)) *ConnectParameters {
cloudflaredID, err := uuid.Parse("ED7BA470-8E54-465E-825C-99712043E01C")
if err != nil {
panic(err)
}
sample := &ConnectParameters{
OriginCert: []byte("my-origin-cert"),
CloudflaredID: cloudflaredID,
NumPreviousAttempts: 19,
Tags: []Tag{
Tag{
Name: "provision-method",
Value: "new",
},
},
CloudflaredVersion: "7.0",
Scope: &Group{group: "my_group"},
}
sample.ensureNoZeroFields()
for _, f := range overrides {
f(sample)
}
return sample
}
func (c *ConnectParameters) ensureNoZeroFields() {
ensureNoZeroFieldsInSample(reflect.ValueOf(c), []string{})
}