mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2307: Capnp is the only serialization format used in tunnelpogs
This commit is contained in:

committed by
Chung Ting Huang

parent
ff795a7beb
commit
fe032843f3
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -29,11 +28,43 @@ import (
|
||||
|
||||
// ClientConfig is a collection of FallibleConfig that determines how cloudflared should function
|
||||
type ClientConfig struct {
|
||||
Version Version `json:"version"`
|
||||
SupervisorConfig *SupervisorConfig `json:"supervisor_config"`
|
||||
EdgeConnectionConfig *EdgeConnectionConfig `json:"edge_connection_config"`
|
||||
DoHProxyConfigs []*DoHProxyConfig `json:"doh_proxy_configs" capnp:"dohProxyConfigs"`
|
||||
ReverseProxyConfigs []*ReverseProxyConfig `json:"reverse_proxy_configs"`
|
||||
Version Version
|
||||
SupervisorConfig *SupervisorConfig
|
||||
EdgeConnectionConfig *EdgeConnectionConfig
|
||||
DoHProxyConfigs []*DoHProxyConfig `capnp:"dohProxyConfigs"`
|
||||
ReverseProxyConfigs []*ReverseProxyConfig
|
||||
}
|
||||
|
||||
func (c *ClientConfig) MarshalBytes() ([]byte, error) {
|
||||
msg, firstSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
capnpEntity, err := tunnelrpc.NewRootClientConfig(firstSeg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = MarshalClientConfig(capnpEntity, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg.Marshal()
|
||||
}
|
||||
|
||||
func UnmarshalClientConfigFromBytes(clientConfigBytes []byte) (*ClientConfig, error) {
|
||||
msg, err := capnp.Unmarshal(clientConfigBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
capnpClientConfig, err := tunnelrpc.ReadRootClientConfig(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pogsClientConfig, err := UnmarshalClientConfig(capnpClientConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return pogsClientConfig, nil
|
||||
}
|
||||
|
||||
// Version type models the version of a ClientConfig
|
||||
@@ -52,16 +83,17 @@ func (v Version) String() string {
|
||||
}
|
||||
|
||||
// FallibleConfig is an interface implemented by configs that cloudflared might not be able to apply
|
||||
//go-sumtype:decl FallibleConfig
|
||||
type FallibleConfig interface {
|
||||
FailReason(err error) string
|
||||
jsonType() string
|
||||
isFallibleConfig()
|
||||
}
|
||||
|
||||
// SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager
|
||||
type SupervisorConfig struct {
|
||||
AutoUpdateFrequency time.Duration `json:"auto_update_frequency"`
|
||||
MetricsUpdateFrequency time.Duration `json:"metrics_update_frequency"`
|
||||
GracePeriod time.Duration `json:"grace_period"`
|
||||
AutoUpdateFrequency time.Duration
|
||||
MetricsUpdateFrequency time.Duration
|
||||
GracePeriod time.Duration
|
||||
}
|
||||
|
||||
// FailReason impelents FallibleConfig interface for SupervisorConfig
|
||||
@@ -69,23 +101,15 @@ func (sc *SupervisorConfig) FailReason(err error) string {
|
||||
return fmt.Sprintf("Cannot apply SupervisorConfig, err: %v", err)
|
||||
}
|
||||
|
||||
func (sc *SupervisorConfig) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]SupervisorConfig, 1)
|
||||
marshaler[sc.jsonType()] = *sc
|
||||
return json.Marshal(marshaler)
|
||||
}
|
||||
|
||||
func (sc *SupervisorConfig) jsonType() string {
|
||||
return "supervisor_config"
|
||||
}
|
||||
func (_ *SupervisorConfig) isFallibleConfig() {}
|
||||
|
||||
// EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge
|
||||
type EdgeConnectionConfig struct {
|
||||
NumHAConnections uint8 `json:"num_ha_connections"`
|
||||
HeartbeatInterval time.Duration `json:"heartbeat_interval"`
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
MaxFailedHeartbeats uint64 `json:"max_failed_heartbeats"`
|
||||
UserCredentialPath string `json:"user_credential_path"`
|
||||
NumHAConnections uint8
|
||||
HeartbeatInterval time.Duration
|
||||
Timeout time.Duration
|
||||
MaxFailedHeartbeats uint64
|
||||
UserCredentialPath string
|
||||
}
|
||||
|
||||
// FailReason impelents FallibleConfig interface for EdgeConnectionConfig
|
||||
@@ -93,21 +117,13 @@ func (cmc *EdgeConnectionConfig) FailReason(err error) string {
|
||||
return fmt.Sprintf("Cannot apply EdgeConnectionConfig, err: %v", err)
|
||||
}
|
||||
|
||||
func (cmc *EdgeConnectionConfig) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]EdgeConnectionConfig, 1)
|
||||
marshaler[cmc.jsonType()] = *cmc
|
||||
return json.Marshal(marshaler)
|
||||
}
|
||||
|
||||
func (cmc *EdgeConnectionConfig) jsonType() string {
|
||||
return "edge_connection_config"
|
||||
}
|
||||
func (_ *EdgeConnectionConfig) isFallibleConfig() {}
|
||||
|
||||
// DoHProxyConfig is configuration for DNS over HTTPS service
|
||||
type DoHProxyConfig struct {
|
||||
ListenHost string `json:"listen_host"`
|
||||
ListenPort uint16 `json:"listen_port"`
|
||||
Upstreams []string `json:"upstreams"`
|
||||
ListenHost string
|
||||
ListenPort uint16
|
||||
Upstreams []string
|
||||
}
|
||||
|
||||
// FailReason impelents FallibleConfig interface for DoHProxyConfig
|
||||
@@ -115,23 +131,15 @@ func (dpc *DoHProxyConfig) FailReason(err error) string {
|
||||
return fmt.Sprintf("Cannot apply DoHProxyConfig, err: %v", err)
|
||||
}
|
||||
|
||||
func (dpc *DoHProxyConfig) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]DoHProxyConfig, 1)
|
||||
marshaler[dpc.jsonType()] = *dpc
|
||||
return json.Marshal(marshaler)
|
||||
}
|
||||
|
||||
func (dpc *DoHProxyConfig) jsonType() string {
|
||||
return "doh_proxy_config"
|
||||
}
|
||||
func (_ *DoHProxyConfig) isFallibleConfig() {}
|
||||
|
||||
// ReverseProxyConfig how and for what hostnames can this cloudflared proxy
|
||||
type ReverseProxyConfig struct {
|
||||
TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"`
|
||||
OriginConfigJSONHandler *OriginConfigJSONHandler `json:"origin_config"`
|
||||
Retries uint64 `json:"retries"`
|
||||
ConnectionTimeout time.Duration `json:"connection_timeout"`
|
||||
CompressionQuality uint64 `json:"compression_quality"`
|
||||
TunnelHostname h2mux.TunnelHostname
|
||||
OriginConfig OriginConfig
|
||||
Retries uint64
|
||||
ConnectionTimeout time.Duration
|
||||
CompressionQuality uint64
|
||||
}
|
||||
|
||||
func NewReverseProxyConfig(
|
||||
@@ -145,11 +153,11 @@ func NewReverseProxyConfig(
|
||||
return nil, fmt.Errorf("NewReverseProxyConfig: originConfigUnmarshaler was null")
|
||||
}
|
||||
return &ReverseProxyConfig{
|
||||
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
|
||||
OriginConfigJSONHandler: &OriginConfigJSONHandler{originConfig},
|
||||
Retries: retries,
|
||||
ConnectionTimeout: connectionTimeout,
|
||||
CompressionQuality: compressionQuality,
|
||||
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
|
||||
OriginConfig: originConfig,
|
||||
Retries: retries,
|
||||
ConnectionTimeout: connectionTimeout,
|
||||
CompressionQuality: compressionQuality,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -158,58 +166,29 @@ func (rpc *ReverseProxyConfig) FailReason(err error) string {
|
||||
return fmt.Sprintf("Cannot apply ReverseProxyConfig, err: %v", err)
|
||||
}
|
||||
|
||||
func (rpc *ReverseProxyConfig) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]ReverseProxyConfig, 1)
|
||||
marshaler[rpc.jsonType()] = *rpc
|
||||
return json.Marshal(marshaler)
|
||||
}
|
||||
|
||||
func (rpc *ReverseProxyConfig) jsonType() string {
|
||||
return "reverse_proxy_config"
|
||||
}
|
||||
func (_ *ReverseProxyConfig) isFallibleConfig() {}
|
||||
|
||||
//go-sumtype:decl OriginConfig
|
||||
type OriginConfig interface {
|
||||
// Service returns a OriginService used to proxy to the origin
|
||||
Service() (originservice.OriginService, error)
|
||||
// go-sumtype requires at least one unexported method, otherwise it will complain that interface is not sealed
|
||||
jsonType() string
|
||||
}
|
||||
|
||||
type originType int
|
||||
|
||||
const (
|
||||
httpType originType = iota
|
||||
wsType
|
||||
helloWorldType
|
||||
)
|
||||
|
||||
func (ot originType) String() string {
|
||||
switch ot {
|
||||
case httpType:
|
||||
return "Http"
|
||||
case wsType:
|
||||
return "WebSocket"
|
||||
case helloWorldType:
|
||||
return "HelloWorld"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
isOriginConfig()
|
||||
}
|
||||
|
||||
type HTTPOriginConfig struct {
|
||||
URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"`
|
||||
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" json:"tcp_keep_alive" mapstructure:"tcp_keep_alive"`
|
||||
DialDualStack bool `json:"dial_dual_stack" mapstructure:"dial_dual_stack"`
|
||||
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" json:"tls_handshake_timeout" mapstructure:"tls_handshake_timeout"`
|
||||
TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"`
|
||||
OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"`
|
||||
OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"`
|
||||
MaxIdleConnections uint64 `json:"max_idle_connections" mapstructure:"max_idle_connections"`
|
||||
IdleConnectionTimeout time.Duration `json:"idle_connection_timeout" mapstructure:"idle_connection_timeout"`
|
||||
ProxyConnectionTimeout time.Duration `json:"proxy_connection_timeout" mapstructure:"proxy_connection_timeout"`
|
||||
ExpectContinueTimeout time.Duration `json:"expect_continue_timeout" mapstructure:"expect_continue_timeout"`
|
||||
ChunkedEncoding bool `json:"chunked_encoding" mapstructure:"chunked_encoding"`
|
||||
URLString string `capnp:"urlString"`
|
||||
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
|
||||
DialDualStack bool
|
||||
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
OriginCAPool string
|
||||
OriginServerName string
|
||||
MaxIdleConnections uint64
|
||||
IdleConnectionTimeout time.Duration
|
||||
ProxyConnectionTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
ChunkedEncoding bool
|
||||
}
|
||||
|
||||
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
@@ -248,15 +227,13 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
|
||||
return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil
|
||||
}
|
||||
|
||||
func (_ *HTTPOriginConfig) jsonType() string {
|
||||
return httpType.String()
|
||||
}
|
||||
func (*HTTPOriginConfig) isOriginConfig() {}
|
||||
|
||||
type WebSocketOriginConfig struct {
|
||||
URLString string `capnp:"urlString" json:"url_string" mapstructure:"url_string"`
|
||||
TLSVerify bool `capnp:"tlsVerify" json:"tls_verify" mapstructure:"tls_verify"`
|
||||
OriginCAPool string `json:"origin_ca_pool" mapstructure:"origin_ca_pool"`
|
||||
OriginServerName string `json:"origin_server_name" mapstructure:"origin_server_name"`
|
||||
URLString string `capnp:"urlString"`
|
||||
TLSVerify bool `capnp:"tlsVerify"`
|
||||
OriginCAPool string
|
||||
OriginServerName string
|
||||
}
|
||||
|
||||
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
|
||||
@@ -277,13 +254,11 @@ func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error)
|
||||
return originservice.NewWebSocketService(tlsConfig, url)
|
||||
}
|
||||
|
||||
func (_ *WebSocketOriginConfig) jsonType() string {
|
||||
return wsType.String()
|
||||
}
|
||||
func (*WebSocketOriginConfig) isOriginConfig() {}
|
||||
|
||||
type HelloWorldOriginConfig struct{}
|
||||
|
||||
func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
||||
func (*HelloWorldOriginConfig) Service() (originservice.OriginService, error) {
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Cannot get Hello World server certificate")
|
||||
@@ -308,9 +283,7 @@ func (_ *HelloWorldOriginConfig) Service() (originservice.OriginService, error)
|
||||
return originservice.NewHelloWorldService(transport)
|
||||
}
|
||||
|
||||
func (_ *HelloWorldOriginConfig) jsonType() string {
|
||||
return helloWorldType.String()
|
||||
}
|
||||
func (*HelloWorldOriginConfig) isOriginConfig() {}
|
||||
|
||||
/*
|
||||
* Boilerplate to convert between these structs and the primitive structs
|
||||
@@ -519,9 +492,9 @@ func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error
|
||||
|
||||
func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error {
|
||||
s.SetTunnelHostname(p.TunnelHostname.String())
|
||||
switch config := p.OriginConfigJSONHandler.OriginConfig.(type) {
|
||||
switch config := p.OriginConfig.(type) {
|
||||
case *HTTPOriginConfig:
|
||||
ss, err := s.Origin().NewHttp()
|
||||
ss, err := s.OriginConfig().NewHttp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -529,7 +502,7 @@ func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyCo
|
||||
return err
|
||||
}
|
||||
case *WebSocketOriginConfig:
|
||||
ss, err := s.Origin().NewWebsocket()
|
||||
ss, err := s.OriginConfig().NewWebsocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -537,7 +510,7 @@ func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyCo
|
||||
return err
|
||||
}
|
||||
case *HelloWorldOriginConfig:
|
||||
ss, err := s.Origin().NewHelloWorld()
|
||||
ss, err := s.OriginConfig().NewHelloWorld()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -560,9 +533,9 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
|
||||
return nil, err
|
||||
}
|
||||
p.TunnelHostname = h2mux.TunnelHostname(tunnelHostname)
|
||||
switch s.Origin().Which() {
|
||||
case tunnelrpc.ReverseProxyConfig_origin_Which_http:
|
||||
ss, err := s.Origin().Http()
|
||||
switch s.OriginConfig().Which() {
|
||||
case tunnelrpc.ReverseProxyConfig_originConfig_Which_http:
|
||||
ss, err := s.OriginConfig().Http()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -570,9 +543,9 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
|
||||
case tunnelrpc.ReverseProxyConfig_origin_Which_websocket:
|
||||
ss, err := s.Origin().Websocket()
|
||||
p.OriginConfig = config
|
||||
case tunnelrpc.ReverseProxyConfig_originConfig_Which_websocket:
|
||||
ss, err := s.OriginConfig().Websocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -580,9 +553,9 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
|
||||
case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld:
|
||||
ss, err := s.Origin().HelloWorld()
|
||||
p.OriginConfig = config
|
||||
case tunnelrpc.ReverseProxyConfig_originConfig_Which_helloWorld:
|
||||
ss, err := s.OriginConfig().HelloWorld()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -590,7 +563,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.OriginConfigJSONHandler = &OriginConfigJSONHandler{config}
|
||||
p.OriginConfig = config
|
||||
}
|
||||
p.Retries = s.Retries()
|
||||
p.ConnectionTimeout = time.Duration(s.ConnectionTimeout())
|
||||
@@ -690,13 +663,13 @@ func (i ClientService_PogsImpl) UseConfiguration(p tunnelrpc.ClientService_useCo
|
||||
}
|
||||
|
||||
type UseConfigurationResult struct {
|
||||
Success bool `json:"success"`
|
||||
FailedConfigs []*FailedConfig `json:"failed_configs"`
|
||||
Success bool
|
||||
FailedConfigs []*FailedConfig
|
||||
}
|
||||
|
||||
type FailedConfig struct {
|
||||
Config FallibleConfig `json:"config"`
|
||||
Reason string `json:"reason"`
|
||||
Config FallibleConfig
|
||||
Reason string
|
||||
}
|
||||
|
||||
func MarshalFailedConfig(s tunnelrpc.FailedConfig, p *FailedConfig) error {
|
||||
|
@@ -61,13 +61,13 @@ func ClientConfigTestCases() []*ClientConfig {
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
|
||||
c.OriginConfig = sampleHTTPOriginConfig()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()}
|
||||
c.OriginConfig = sampleHTTPOriginConfigUnixPath()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
|
||||
c.OriginConfig = sampleWebSocketOriginConfig()
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -83,21 +83,14 @@ func ClientConfigTestCases() []*ClientConfig {
|
||||
}
|
||||
|
||||
func TestClientConfig(t *testing.T) {
|
||||
for i, testCase := range ClientConfigTestCases() {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewClientConfig(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalClientConfig(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalClientConfig(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
for _, testCase := range ClientConfigTestCases() {
|
||||
b, err := testCase.MarshalBytes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
clientConfig, err := UnmarshalClientConfigFromBytes(b)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testCase, clientConfig)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,13 +160,13 @@ func TestReverseProxyConfig(t *testing.T) {
|
||||
testCases := []*ReverseProxyConfig{
|
||||
sampleReverseProxyConfig(),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
|
||||
c.OriginConfig = sampleHTTPOriginConfig()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfigUnixPath()}
|
||||
c.OriginConfig = sampleHTTPOriginConfigUnixPath()
|
||||
}),
|
||||
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
|
||||
c.OriginConfig = sampleWebSocketOriginConfig()
|
||||
}),
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
@@ -323,11 +316,11 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
|
||||
// applies any number of overrides to it, and returns it.
|
||||
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
|
||||
sample := &ReverseProxyConfig{
|
||||
TunnelHostname: "mock-non-lb-tunnel.example.com",
|
||||
OriginConfigJSONHandler: &OriginConfigJSONHandler{&HelloWorldOriginConfig{}},
|
||||
Retries: 18,
|
||||
ConnectionTimeout: 5 * time.Second,
|
||||
CompressionQuality: 3,
|
||||
TunnelHostname: "mock-non-lb-tunnel.example.com",
|
||||
OriginConfig: &HelloWorldOriginConfig{},
|
||||
Retries: 18,
|
||||
ConnectionTimeout: 5 * time.Second,
|
||||
CompressionQuality: 3,
|
||||
}
|
||||
sample.ensureNoZeroFields()
|
||||
for _, f := range overrides {
|
||||
|
@@ -1,101 +0,0 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// ScopeUnmarshaler can marshal a Scope pog from JSON.
|
||||
type ScopeUnmarshaler struct {
|
||||
Scope Scope
|
||||
}
|
||||
|
||||
// UnmarshalJSON takes in a JSON string, and attempts to marshal it into a Scope.
|
||||
// If successful, the Scope member of this ScopeUnmarshaler is set and nil is returned.
|
||||
// If unsuccessful, returns an error.
|
||||
func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error {
|
||||
var scopeJSON map[string]interface{}
|
||||
if err := json.Unmarshal(b, &scopeJSON); err != nil {
|
||||
return errors.Wrapf(err, "cannot unmarshal %s into scopeJSON", string(b))
|
||||
}
|
||||
|
||||
if group, ok := scopeJSON["group"]; ok {
|
||||
if val, ok := group.(string); ok {
|
||||
su.Scope = NewGroup(val)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("JSON should have been a Scope, but the 'group' key contained %v", group)
|
||||
}
|
||||
|
||||
if systemName, ok := scopeJSON["system_name"]; ok {
|
||||
if val, ok := systemName.(string); ok {
|
||||
su.Scope = NewSystemName(val)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("JSON should have been a Scope, but the 'system_name' key contained %v", systemName)
|
||||
}
|
||||
|
||||
return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'")
|
||||
}
|
||||
|
||||
// OriginConfigJSONHandler is a wrapper to serialize OriginConfig with type information, and deserialize JSON
|
||||
// into an OriginConfig.
|
||||
type OriginConfigJSONHandler struct {
|
||||
OriginConfig OriginConfig
|
||||
}
|
||||
|
||||
func (ocjh *OriginConfigJSONHandler) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]OriginConfig, 1)
|
||||
marshaler[ocjh.OriginConfig.jsonType()] = ocjh.OriginConfig
|
||||
return json.Marshal(marshaler)
|
||||
}
|
||||
|
||||
func (ocjh *OriginConfigJSONHandler) UnmarshalJSON(b []byte) error {
|
||||
var originJSON map[string]interface{}
|
||||
if err := json.Unmarshal(b, &originJSON); err != nil {
|
||||
return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b))
|
||||
}
|
||||
|
||||
if originConfig, ok := originJSON[httpType.String()]; ok {
|
||||
httpOriginConfig := &HTTPOriginConfig{}
|
||||
if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil {
|
||||
return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig)
|
||||
}
|
||||
ocjh.OriginConfig = httpOriginConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
if originConfig, ok := originJSON[wsType.String()]; ok {
|
||||
wsOriginConfig := &WebSocketOriginConfig{}
|
||||
if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil {
|
||||
return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig)
|
||||
}
|
||||
ocjh.OriginConfig = wsOriginConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
if originConfig, ok := originJSON[helloWorldType.String()]; ok {
|
||||
helloWorldOriginConfig := &HelloWorldOriginConfig{}
|
||||
if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil {
|
||||
return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig)
|
||||
}
|
||||
ocjh.OriginConfig = helloWorldOriginConfig
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b))
|
||||
}
|
||||
|
||||
// FallibleConfigMarshaler is a wrapper for FallibleConfig to implement custom marshal logic
|
||||
type FallibleConfigMarshaler struct {
|
||||
FallibleConfig FallibleConfig
|
||||
}
|
||||
|
||||
func (fcm *FallibleConfigMarshaler) MarshalJSON() ([]byte, error) {
|
||||
marshaler := make(map[string]FallibleConfig, 1)
|
||||
marshaler[fcm.FallibleConfig.jsonType()] = fcm.FallibleConfig
|
||||
return json.Marshal(marshaler)
|
||||
}
|
@@ -1,342 +0,0 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
|
||||
type fields struct {
|
||||
Scope Scope
|
||||
}
|
||||
type args struct {
|
||||
b []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
wantScope Scope
|
||||
}{
|
||||
{
|
||||
name: "group_successful",
|
||||
args: args{b: []byte(`{"group": "my-group"}`)},
|
||||
wantScope: NewGroup("my-group"),
|
||||
},
|
||||
{
|
||||
name: "system_name_successful",
|
||||
args: args{b: []byte(`{"system_name": "my-computer"}`)},
|
||||
wantScope: NewSystemName("my-computer"),
|
||||
},
|
||||
{
|
||||
name: "not_a_scope",
|
||||
args: args{b: []byte(`{"x": "y"}`)},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "malformed_group",
|
||||
args: args{b: []byte(`{"group": ["a", "b"]}`)},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
su := &ScopeUnmarshaler{
|
||||
Scope: tt.fields.Scope,
|
||||
}
|
||||
err := su.UnmarshalJSON(tt.args.b)
|
||||
if !tt.wantErr {
|
||||
if err != nil {
|
||||
t.Errorf("ScopeUnmarshaler.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if !eqScope(tt.wantScope, su.Scope) {
|
||||
t.Errorf("Wanted scope %v but got scope %v", tt.wantScope, su.Scope)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
jsonLiteral string
|
||||
exceptedOriginConfig OriginConfig
|
||||
}{
|
||||
{
|
||||
jsonLiteral: `{
|
||||
"Http":{
|
||||
"url_string":"https.example.com",
|
||||
"tcp_keep_alive":7000000000,
|
||||
"dial_dual_stack":true,
|
||||
"tls_handshake_timeout":11000000000,
|
||||
"tls_verify":true,
|
||||
"origin_ca_pool":"/etc/cert.pem",
|
||||
"origin_server_name":"secure.example.com",
|
||||
"max_idle_connections":19,
|
||||
"idle_connection_timeout":17000000000,
|
||||
"proxy_connection_timeout":15000000000,
|
||||
"expect_continue_timeout":21000000000,
|
||||
"chunked_encoding":true
|
||||
}
|
||||
}`,
|
||||
exceptedOriginConfig: sampleHTTPOriginConfig(),
|
||||
},
|
||||
{
|
||||
jsonLiteral: `{
|
||||
"WebSocket":{
|
||||
"url_string":"ssh://example.com",
|
||||
"tls_verify":true,
|
||||
"origin_ca_pool":"/etc/cert.pem",
|
||||
"origin_server_name":"secure.example.com"
|
||||
}
|
||||
}`,
|
||||
exceptedOriginConfig: sampleWebSocketOriginConfig(),
|
||||
},
|
||||
{
|
||||
jsonLiteral: `{
|
||||
"HelloWorld": {}
|
||||
}`,
|
||||
exceptedOriginConfig: &HelloWorldOriginConfig{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
originConfigJSON := prettyToValidJSON(test.jsonLiteral)
|
||||
var OriginConfigJSONHandler OriginConfigJSONHandler
|
||||
err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigJSONHandler)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.exceptedOriginConfig, OriginConfigJSONHandler.OriginConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalClientConfig(t *testing.T) {
|
||||
prettyClientConfigJSON := `{
|
||||
"version":10,
|
||||
"supervisor_config":{
|
||||
"auto_update_frequency":86400000000000,
|
||||
"metrics_update_frequency":300000000000,
|
||||
"grace_period":30000000000
|
||||
},
|
||||
"edge_connection_config":{
|
||||
"num_ha_connections":4,
|
||||
"heartbeat_interval":5000000000,
|
||||
"timeout":30000000000,
|
||||
"max_failed_heartbeats":5,
|
||||
"user_credential_path":"~/.cloudflared/cert.pem"
|
||||
},
|
||||
"doh_proxy_configs":[{
|
||||
"listen_host": "localhost",
|
||||
"listen_port": 53,
|
||||
"upstreams": ["https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"]
|
||||
}],
|
||||
"reverse_proxy_configs":[{
|
||||
"tunnel_hostname":"sdfjadk33.cftunnel.com",
|
||||
"origin_config":{
|
||||
"Http":{
|
||||
"url_string":"https://127.0.0.1:8080",
|
||||
"tcp_keep_alive":30000000000,
|
||||
"dial_dual_stack":true,
|
||||
"tls_handshake_timeout":10000000000,
|
||||
"tls_verify":true,
|
||||
"origin_ca_pool":"",
|
||||
"origin_server_name":"",
|
||||
"max_idle_connections":100,
|
||||
"idle_connection_timeout":90000000000,
|
||||
"proxy_connection_timeout":90000000000,
|
||||
"expect_continue_timeout":90000000000,
|
||||
"chunked_encoding":true
|
||||
}
|
||||
},
|
||||
"retries":5,
|
||||
"connection_timeout":30,
|
||||
"compression_quality":0
|
||||
}]
|
||||
}`
|
||||
// replace new line and tab
|
||||
clientConfigJSON := prettyToValidJSON(prettyClientConfigJSON)
|
||||
|
||||
var clientConfig ClientConfig
|
||||
err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, Version(10), clientConfig.Version)
|
||||
|
||||
supervisorConfig := SupervisorConfig{
|
||||
AutoUpdateFrequency: time.Hour * 24,
|
||||
MetricsUpdateFrequency: time.Second * 300,
|
||||
GracePeriod: time.Second * 30,
|
||||
}
|
||||
assert.Equal(t, supervisorConfig, *clientConfig.SupervisorConfig)
|
||||
|
||||
edgeConnectionConfig := EdgeConnectionConfig{
|
||||
NumHAConnections: 4,
|
||||
HeartbeatInterval: time.Second * 5,
|
||||
Timeout: time.Second * 30,
|
||||
MaxFailedHeartbeats: 5,
|
||||
UserCredentialPath: "~/.cloudflared/cert.pem",
|
||||
}
|
||||
assert.Equal(t, edgeConnectionConfig, *clientConfig.EdgeConnectionConfig)
|
||||
|
||||
dohProxyConfig := DoHProxyConfig{
|
||||
ListenHost: "localhost",
|
||||
ListenPort: 53,
|
||||
Upstreams: []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"},
|
||||
}
|
||||
|
||||
assert.Len(t, clientConfig.DoHProxyConfigs, 1)
|
||||
assert.Equal(t, dohProxyConfig, *clientConfig.DoHProxyConfigs[0])
|
||||
|
||||
reverseProxyConfig := ReverseProxyConfig{
|
||||
TunnelHostname: "sdfjadk33.cftunnel.com",
|
||||
OriginConfigJSONHandler: &OriginConfigJSONHandler{
|
||||
OriginConfig: &HTTPOriginConfig{
|
||||
URLString: "https://127.0.0.1:8080",
|
||||
TCPKeepAlive: time.Second * 30,
|
||||
DialDualStack: true,
|
||||
TLSHandshakeTimeout: time.Second * 10,
|
||||
TLSVerify: true,
|
||||
OriginCAPool: "",
|
||||
OriginServerName: "",
|
||||
MaxIdleConnections: 100,
|
||||
IdleConnectionTimeout: time.Second * 90,
|
||||
ProxyConnectionTimeout: time.Second * 90,
|
||||
ExpectContinueTimeout: time.Second * 90,
|
||||
ChunkedEncoding: true,
|
||||
},
|
||||
},
|
||||
Retries: 5,
|
||||
ConnectionTimeout: 30,
|
||||
CompressionQuality: 0,
|
||||
}
|
||||
|
||||
assert.Len(t, clientConfig.ReverseProxyConfigs, 1)
|
||||
assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0])
|
||||
}
|
||||
|
||||
func TestMarshalFallibleConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
fallibleConfig FallibleConfig
|
||||
expctedJSONLiteral string
|
||||
}{
|
||||
{
|
||||
fallibleConfig: sampleSupervisorConfig(),
|
||||
expctedJSONLiteral: `{
|
||||
"supervisor_config":{
|
||||
"auto_update_frequency":75600000000000,
|
||||
"metrics_update_frequency":660000000000,
|
||||
"grace_period":31000000000
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
fallibleConfig: sampleEdgeConnectionConfig(),
|
||||
expctedJSONLiteral: `{
|
||||
"edge_connection_config":{
|
||||
"num_ha_connections":49,
|
||||
"heartbeat_interval":5000000000,
|
||||
"timeout":9000000000,
|
||||
"max_failed_heartbeats":9001,
|
||||
"user_credential_path":"/Users/example/.cloudflared/cert.pem"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
fallibleConfig: sampleDoHProxyConfig(),
|
||||
expctedJSONLiteral: `{
|
||||
"doh_proxy_config":{
|
||||
"listen_host":"127.0.0.1",
|
||||
"listen_port":53,
|
||||
"upstreams":["1.1.1.1","1.0.0.1"]
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleHTTPOriginConfig()}
|
||||
}),
|
||||
expctedJSONLiteral: `{
|
||||
"reverse_proxy_config":{
|
||||
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
|
||||
"origin_config":{
|
||||
"Http":{
|
||||
"url_string":"https.example.com",
|
||||
"tcp_keep_alive":7000000000,
|
||||
"dial_dual_stack":true,
|
||||
"tls_handshake_timeout":11000000000,
|
||||
"tls_verify":true,
|
||||
"origin_ca_pool":"/etc/cert.pem",
|
||||
"origin_server_name":"secure.example.com",
|
||||
"max_idle_connections":19,
|
||||
"idle_connection_timeout":17000000000,
|
||||
"proxy_connection_timeout":15000000000,
|
||||
"expect_continue_timeout":21000000000,
|
||||
"chunked_encoding":true
|
||||
}
|
||||
},
|
||||
"retries":18,
|
||||
"connection_timeout":5000000000,
|
||||
"compression_quality":3
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{sampleWebSocketOriginConfig()}
|
||||
}),
|
||||
expctedJSONLiteral: `{
|
||||
"reverse_proxy_config":{
|
||||
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
|
||||
"origin_config":{
|
||||
"WebSocket":{
|
||||
"url_string":"ssh://example.com",
|
||||
"tls_verify":true,
|
||||
"origin_ca_pool":"/etc/cert.pem",
|
||||
"origin_server_name":"secure.example.com"
|
||||
}
|
||||
},
|
||||
"retries":18,
|
||||
"connection_timeout":5000000000,
|
||||
"compression_quality":3
|
||||
}
|
||||
}`,
|
||||
},
|
||||
{
|
||||
fallibleConfig: sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
|
||||
c.OriginConfigJSONHandler = &OriginConfigJSONHandler{&HelloWorldOriginConfig{}}
|
||||
}),
|
||||
expctedJSONLiteral: `{
|
||||
"reverse_proxy_config":{
|
||||
"tunnel_hostname":"mock-non-lb-tunnel.example.com",
|
||||
"origin_config":{
|
||||
"HelloWorld":{}
|
||||
},
|
||||
"retries":18,
|
||||
"connection_timeout":5000000000,
|
||||
"compression_quality":3
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
b, err := json.Marshal(test.fallibleConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, prettyToValidJSON(test.expctedJSONLiteral), string(b))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type prettyJSON string
|
||||
|
||||
func prettyToValidJSON(prettyJSON string) string {
|
||||
return strings.ReplaceAll(strings.ReplaceAll(prettyJSON, "\n", ""), "\t", "")
|
||||
}
|
||||
|
||||
func eqScope(s1, s2 Scope) bool {
|
||||
return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType()
|
||||
}
|
Reference in New Issue
Block a user