mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2260: add name/group to CapnpConnectParameters, remove Scope
This commit is contained in:
@@ -8,39 +8,6 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ScopeUnmarshaler can marshal a Scope pog from JSON.
|
||||
type ScopeUnmarshaler struct {
|
||||
Scope Scope
|
||||
}
|
||||
|
||||
// UnmarshalJSON takes in a JSON string, and attempts to marshal it into a Scope.
|
||||
// If successful, the Scope member of this ScopeUnmarshaler is set and nil is returned.
|
||||
// If unsuccessful, returns an error.
|
||||
func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error {
|
||||
var scopeJSON map[string]interface{}
|
||||
if err := json.Unmarshal(b, &scopeJSON); err != nil {
|
||||
return errors.Wrapf(err, "cannot unmarshal %s into scopeJSON", string(b))
|
||||
}
|
||||
|
||||
if group, ok := scopeJSON["group"]; ok {
|
||||
if val, ok := group.(string); ok {
|
||||
su.Scope = NewGroup(val)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("JSON should have been a Scope, but the 'group' key contained %v", group)
|
||||
}
|
||||
|
||||
if systemName, ok := scopeJSON["system_name"]; ok {
|
||||
if val, ok := systemName.(string); ok {
|
||||
su.Scope = NewSystemName(val)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("JSON should have been a Scope, but the 'system_name' key contained %v", systemName)
|
||||
}
|
||||
|
||||
return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'")
|
||||
}
|
||||
|
||||
// OriginConfigJSONHandler is a wrapper to serialize OriginConfig with type information, and deserialize JSON
|
||||
// into an OriginConfig.
|
||||
type OriginConfigJSONHandler struct {
|
||||
|
@@ -9,59 +9,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
|
||||
type fields struct {
|
||||
Scope Scope
|
||||
}
|
||||
type args struct {
|
||||
b []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
wantScope Scope
|
||||
}{
|
||||
{
|
||||
name: "group_successful",
|
||||
args: args{b: []byte(`{"group": "my-group"}`)},
|
||||
wantScope: NewGroup("my-group"),
|
||||
},
|
||||
{
|
||||
name: "system_name_successful",
|
||||
args: args{b: []byte(`{"system_name": "my-computer"}`)},
|
||||
wantScope: NewSystemName("my-computer"),
|
||||
},
|
||||
{
|
||||
name: "not_a_scope",
|
||||
args: args{b: []byte(`{"x": "y"}`)},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed_group",
|
||||
args: args{b: []byte(`{"group": ["a", "b"]}`)},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
su := &ScopeUnmarshaler{
|
||||
Scope: tt.fields.Scope,
|
||||
}
|
||||
err := su.UnmarshalJSON(tt.args.b)
|
||||
if !tt.wantErr {
|
||||
if err != nil {
|
||||
t.Errorf("ScopeUnmarshaler.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !eqScope(tt.wantScope, su.Scope) {
|
||||
t.Errorf("Wanted scope %v but got scope %v", tt.wantScope, su.Scope)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
jsonLiteral string
|
||||
@@ -336,7 +283,3 @@ type prettyJSON string
|
||||
func prettyToValidJSON(prettyJSON string) string {
|
||||
return strings.ReplaceAll(strings.ReplaceAll(prettyJSON, "\n", ""), "\t", "")
|
||||
}
|
||||
|
||||
func eqScope(s1, s2 Scope) bool {
|
||||
return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType()
|
||||
}
|
||||
|
@@ -2,7 +2,6 @@ package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
@@ -129,82 +128,14 @@ func UnmarshalServerInfo(s tunnelrpc.ServerInfo) (*ServerInfo, error) {
|
||||
return p, err
|
||||
}
|
||||
|
||||
//go-sumtype:decl Scope
|
||||
type Scope interface {
|
||||
Value() string
|
||||
PostgresType() string
|
||||
GraphQLType() string
|
||||
isScope()
|
||||
}
|
||||
|
||||
type SystemName struct {
|
||||
systemName string
|
||||
}
|
||||
|
||||
func NewSystemName(systemName string) *SystemName {
|
||||
return &SystemName{systemName: systemName}
|
||||
}
|
||||
|
||||
func (s *SystemName) Value() string { return s.systemName }
|
||||
func (_ *SystemName) PostgresType() string { return "system_name" }
|
||||
func (_ *SystemName) GraphQLType() string { return "SYSTEM_NAME" }
|
||||
|
||||
func (_ *SystemName) isScope() {}
|
||||
|
||||
type Group struct {
|
||||
group string
|
||||
}
|
||||
|
||||
func NewGroup(group string) *Group {
|
||||
return &Group{group: group}
|
||||
}
|
||||
|
||||
func (g *Group) Value() string { return g.group }
|
||||
func (_ *Group) PostgresType() string { return "group" }
|
||||
func (_ *Group) GraphQLType() string { return "GROUP" }
|
||||
|
||||
func (_ *Group) isScope() {}
|
||||
|
||||
func MarshalScope(s tunnelrpc.Scope, p Scope) error {
|
||||
ss := s.Value()
|
||||
switch scope := p.(type) {
|
||||
case *SystemName:
|
||||
ss.SetSystemName(scope.systemName)
|
||||
case *Group:
|
||||
ss.SetGroup(scope.group)
|
||||
default:
|
||||
return fmt.Errorf("unexpected Scope value: %v", p)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalScope(s tunnelrpc.Scope) (Scope, error) {
|
||||
ss := s.Value()
|
||||
switch ss.Which() {
|
||||
case tunnelrpc.Scope_value_Which_systemName:
|
||||
systemName, err := ss.SystemName()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSystemName(systemName), nil
|
||||
case tunnelrpc.Scope_value_Which_group:
|
||||
group, err := ss.Group()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewGroup(group), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected Scope tag: %v", ss.Which())
|
||||
}
|
||||
}
|
||||
|
||||
type ConnectParameters struct {
|
||||
OriginCert []byte
|
||||
CloudflaredID uuid.UUID
|
||||
NumPreviousAttempts uint8
|
||||
Tags []Tag
|
||||
CloudflaredVersion string
|
||||
Scope Scope
|
||||
Name string
|
||||
Group string
|
||||
}
|
||||
|
||||
func MarshalConnectParameters(s tunnelrpc.CapnpConnectParameters, p *ConnectParameters) error {
|
||||
@@ -237,11 +168,13 @@ func MarshalConnectParameters(s tunnelrpc.CapnpConnectParameters, p *ConnectPara
|
||||
if err := s.SetCloudflaredVersion(p.CloudflaredVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
scope, err := s.NewScope()
|
||||
if err != nil {
|
||||
if err := s.SetName(p.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalScope(scope, p.Scope)
|
||||
if err := s.SetGroup(p.Group); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectParameters, error) {
|
||||
@@ -282,11 +215,12 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scopeCapnp, err := s.Scope()
|
||||
name, err := s.Name()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scope, err := UnmarshalScope(scopeCapnp)
|
||||
|
||||
group, err := s.Group()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -297,7 +231,8 @@ func UnmarshalConnectParameters(s tunnelrpc.CapnpConnectParameters) (*ConnectPar
|
||||
NumPreviousAttempts: s.NumPreviousAttempts(),
|
||||
Tags: tags,
|
||||
CloudflaredVersion: cloudflaredVersion,
|
||||
Scope: scope,
|
||||
Name: name,
|
||||
Group: group,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@@ -11,35 +11,6 @@ import (
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// Assert *SystemName implements Scope
|
||||
var _ Scope = (*SystemName)(nil)
|
||||
|
||||
// Assert *Group implements Scope
|
||||
var _ Scope = (*Group)(nil)
|
||||
|
||||
func TestScope(t *testing.T) {
|
||||
testCases := []Scope{
|
||||
&SystemName{systemName: "my_system"},
|
||||
&Group{group: "my_group"},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewScope(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalScope(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalScope(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
}
|
||||
|
||||
func sampleTestConnectResult() *ConnectResult {
|
||||
return &ConnectResult{
|
||||
Err: &ConnectError{
|
||||
@@ -78,7 +49,7 @@ func TestConnectParameters(t *testing.T) {
|
||||
testCases := []*ConnectParameters{
|
||||
sampleConnectParameters(),
|
||||
sampleConnectParameters(func(c *ConnectParameters) {
|
||||
c.Scope = &SystemName{systemName: "my_system"}
|
||||
c.Name = ""
|
||||
}),
|
||||
sampleConnectParameters(func(c *ConnectParameters) {
|
||||
c.Tags = nil
|
||||
@@ -118,7 +89,8 @@ func sampleConnectParameters(overrides ...func(*ConnectParameters)) *ConnectPara
|
||||
},
|
||||
},
|
||||
CloudflaredVersion: "7.0",
|
||||
Scope: &Group{group: "my_group"},
|
||||
Name: "My Computer",
|
||||
Group: "www",
|
||||
}
|
||||
sample.ensureNoZeroFields()
|
||||
for _, f := range overrides {
|
||||
|
Reference in New Issue
Block a user