TUN-2110: Implement custom deserialization logic for OriginConfig

This commit is contained in:
Chung-Ting Huang
2019-07-31 14:01:23 -05:00
parent 5feba7e3a9
commit bdd70e798a
52 changed files with 1874 additions and 4233 deletions

View File

@@ -14,6 +14,7 @@ import (
"github.com/cloudflare/cloudflared/originservice"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/pkg/errors"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
@@ -27,11 +28,11 @@ import (
// ClientConfig is a collection of FallibleConfig that determines how cloudflared should function
type ClientConfig struct {
Version Version
SupervisorConfig *SupervisorConfig
EdgeConnectionConfig *EdgeConnectionConfig
DoHProxyConfigs []*DoHProxyConfig
ReverseProxyConfigs []*ReverseProxyConfig
Version Version `json:"version"`
SupervisorConfig *SupervisorConfig `json:"supervisor_config"`
EdgeConnectionConfig *EdgeConnectionConfig `json:"edge_connection_config"`
DoHProxyConfigs []*DoHProxyConfig `json:"doh_proxy_configs"`
ReverseProxyConfigs []*ReverseProxyConfig `json:"reverse_proxy_configs"`
}
// Version type models the version of a ClientConfig
@@ -56,9 +57,9 @@ type FallibleConfig interface {
// SupervisorConfig specifies config of components managed by Supervisor other than ConnectionManager
type SupervisorConfig struct {
AutoUpdateFrequency time.Duration
MetricsUpdateFrequency time.Duration
GracePeriod time.Duration
AutoUpdateFrequency time.Duration `json:"auto_update_frequency"`
MetricsUpdateFrequency time.Duration `json:"metrics_update_frequency"`
GracePeriod time.Duration `json:"grace_period"`
}
// FailReason impelents FallibleConfig interface for SupervisorConfig
@@ -68,11 +69,11 @@ func (sc *SupervisorConfig) FailReason(err error) string {
// EdgeConnectionConfig specifies what parameters and how may connections should ConnectionManager establish with edge
type EdgeConnectionConfig struct {
NumHAConnections uint8
HeartbeatInterval time.Duration
Timeout time.Duration
MaxFailedHeartbeats uint64
UserCredentialPath string
NumHAConnections uint8 `json:"num_ha_connections"`
HeartbeatInterval time.Duration `json:"heartbeat_interval"`
Timeout time.Duration `json:"timeout"`
MaxFailedHeartbeats uint64 `json:"max_failed_heartbeats"`
UserCredentialPath string `json:"user_credential_path"`
}
// FailReason impelents FallibleConfig interface for EdgeConnectionConfig
@@ -82,9 +83,9 @@ func (cmc *EdgeConnectionConfig) FailReason(err error) string {
// DoHProxyConfig is configuration for DNS over HTTPS service
type DoHProxyConfig struct {
ListenHost string
ListenPort uint16
Upstreams []string
ListenHost string `json:"listen_host"`
ListenPort uint16 `json:"listen_port"`
Upstreams []string `json:"upstreams"`
}
// FailReason impelents FallibleConfig interface for DoHProxyConfig
@@ -94,11 +95,11 @@ func (dpc *DoHProxyConfig) FailReason(err error) string {
// ReverseProxyConfig how and for what hostnames can this cloudflared proxy
type ReverseProxyConfig struct {
TunnelHostname h2mux.TunnelHostname
Origin OriginConfig
Retries uint64
ConnectionTimeout time.Duration
CompressionQuality uint64
TunnelHostname h2mux.TunnelHostname `json:"tunnel_hostname"`
OriginConfigUnmarshaler *OriginConfigUnmarshaler `json:"origin_config"`
Retries uint64 `json:"retries"`
ConnectionTimeout time.Duration `json:"connection_timeout"`
CompressionQuality uint64 `json:"compression_quality"`
}
func NewReverseProxyConfig(
@@ -109,14 +110,14 @@ func NewReverseProxyConfig(
compressionQuality uint64,
) (*ReverseProxyConfig, error) {
if originConfig == nil {
return nil, fmt.Errorf("NewReverseProxyConfig: originConfig was null")
return nil, fmt.Errorf("NewReverseProxyConfig: originConfigUnmarshaler was null")
}
return &ReverseProxyConfig{
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
Origin: originConfig,
Retries: retries,
ConnectionTimeout: connectionTimeout,
CompressionQuality: compressionQuality,
TunnelHostname: h2mux.TunnelHostname(tunnelHostname),
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{originConfig},
Retries: retries,
ConnectionTimeout: connectionTimeout,
CompressionQuality: compressionQuality,
}, nil
}
@@ -133,19 +134,40 @@ type OriginConfig interface {
isOriginConfig()
}
type originType int
const (
httpType originType = iota
wsType
helloWorldType
)
func (ot originType) String() string {
switch ot {
case httpType:
return "Http"
case wsType:
return "WebSocket"
case helloWorldType:
return "HelloWorld"
default:
return "unknown"
}
}
type HTTPOriginConfig struct {
URLString string `capnp:"urlString"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
DialDualStack bool
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
MaxIdleConnections uint64
IdleConnectionTimeout time.Duration
ProxyConnectionTimeout time.Duration
ExpectContinueTimeout time.Duration
ChunkedEncoding bool
URLString string `capnp:"urlString" mapstructure:"url_string"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive" mapstructure:"tcp_keep_alive"`
DialDualStack bool `mapstructure:"dial_dual_stack"`
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout" mapstructure:"tls_handshake_timeout"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"`
MaxIdleConnections uint64 `mapstructure:"max_idle_connections"`
IdleConnectionTimeout time.Duration `mapstructure:"idle_connection_timeout"`
ProxyConnectionTimeout time.Duration `mapstructure:"proxy_connection_timeout"`
ExpectContinueTimeout time.Duration `mapstructure:"expect_continue_timeout"`
ChunkedEncoding bool `mapstructure:"chunked_encoding"`
}
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
@@ -187,10 +209,10 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
func (_ *HTTPOriginConfig) isOriginConfig() {}
type WebSocketOriginConfig struct {
URLString string `capnp:"urlString"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
URLString string `capnp:"urlString" mapstructure:"url_string"`
TLSVerify bool `capnp:"tlsVerify" mapstructure:"tls_verify"`
OriginCAPool string `mapstructure:"origin_ca_pool"`
OriginServerName string `mapstructure:"origin_server_name"`
}
func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error) {
@@ -449,7 +471,7 @@ func UnmarshalDoHProxyConfig(s tunnelrpc.DoHProxyConfig) (*DoHProxyConfig, error
func MarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig, p *ReverseProxyConfig) error {
s.SetTunnelHostname(p.TunnelHostname.String())
switch config := p.Origin.(type) {
switch config := p.OriginConfigUnmarshaler.OriginConfig.(type) {
case *HTTPOriginConfig:
ss, err := s.Origin().NewHttp()
if err != nil {
@@ -500,7 +522,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_websocket:
ss, err := s.Origin().Websocket()
if err != nil {
@@ -510,7 +532,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
case tunnelrpc.ReverseProxyConfig_origin_Which_helloWorld:
ss, err := s.Origin().HelloWorld()
if err != nil {
@@ -520,7 +542,7 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
if err != nil {
return nil, err
}
p.Origin = config
p.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{config}
}
p.Retries = s.Retries()
p.ConnectionTimeout = time.Duration(s.ConnectionTimeout())

View File

@@ -41,13 +41,13 @@ func TestClientConfig(t *testing.T) {
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
}),
}
}
@@ -142,13 +142,13 @@ func TestReverseProxyConfig(t *testing.T) {
testCases := []*ReverseProxyConfig{
sampleReverseProxyConfig(),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfig()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleHTTPOriginConfig(unixPathOverride)
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleHTTPOriginConfigUnixPath()}
}),
sampleReverseProxyConfig(func(c *ReverseProxyConfig) {
c.Origin = sampleWebSocketOriginConfig()
c.OriginConfigUnmarshaler = &OriginConfigUnmarshaler{sampleWebSocketOriginConfig()}
}),
}
for i, testCase := range testCases {
@@ -285,11 +285,11 @@ func sampleDoHProxyConfig(overrides ...func(*DoHProxyConfig)) *DoHProxyConfig {
// applies any number of overrides to it, and returns it.
func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReverseProxyConfig {
sample := &ReverseProxyConfig{
TunnelHostname: "hijk.example.com",
Origin: &HelloWorldOriginConfig{},
Retries: 18,
ConnectionTimeout: 5 * time.Second,
CompressionQuality: 4,
TunnelHostname: "hijk.example.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{&HelloWorldOriginConfig{}},
Retries: 18,
ConnectionTimeout: 5 * time.Second,
CompressionQuality: 4,
}
sample.ensureNoZeroFields()
for _, f := range overrides {
@@ -298,11 +298,9 @@ func sampleReverseProxyConfig(overrides ...func(*ReverseProxyConfig)) *ReversePr
return sample
}
// sampleHTTPOriginConfig initializes a new HTTPOriginConfig literal,
// applies any number of overrides to it, and returns it.
func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "https://example.com",
URLString: "https.example.com",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
@@ -322,14 +320,28 @@ func sampleHTTPOriginConfig(overrides ...func(*HTTPOriginConfig)) *HTTPOriginCon
return sample
}
// unixPathOverride sets the URLString of the given HTTPOriginConfig to be a
// Unix socket (i.e. `unix:` scheme plus a file path)
func unixPathOverride(sample *HTTPOriginConfig) {
sample.URLString = "unix:/var/lib/file.sock"
func sampleHTTPOriginConfigUnixPath(overrides ...func(*HTTPOriginConfig)) *HTTPOriginConfig {
sample := &HTTPOriginConfig{
URLString: "unix:/var/lib/file.sock",
TCPKeepAlive: 7 * time.Second,
DialDualStack: true,
TLSHandshakeTimeout: 11 * time.Second,
TLSVerify: true,
OriginCAPool: "/etc/cert.pem",
OriginServerName: "secure.example.com",
MaxIdleConnections: 19,
IdleConnectionTimeout: 17 * time.Second,
ProxyConnectionTimeout: 15 * time.Second,
ExpectContinueTimeout: 21 * time.Second,
ChunkedEncoding: true,
}
sample.ensureNoZeroFields()
for _, f := range overrides {
f(sample)
}
return sample
}
// sampleWebSocketOriginConfig initializes a new WebSocketOriginConfig
// struct, applies any number of overrides to it, and returns it.
func sampleWebSocketOriginConfig(overrides ...func(*WebSocketOriginConfig)) *WebSocketOriginConfig {
sample := &WebSocketOriginConfig{
URLString: "ssh://example.com",

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
)
@@ -39,3 +40,43 @@ func (su *ScopeUnmarshaler) UnmarshalJSON(b []byte) error {
return fmt.Errorf("JSON should have been an object with one root key, either 'system_name' or 'group'")
}
type OriginConfigUnmarshaler struct {
OriginConfig OriginConfig
}
func (ocu *OriginConfigUnmarshaler) UnmarshalJSON(b []byte) error {
var originJSON map[string]interface{}
if err := json.Unmarshal(b, &originJSON); err != nil {
return errors.Wrapf(err, "cannot unmarshal %s into originJSON", string(b))
}
if originConfig, ok := originJSON[httpType.String()]; ok {
httpOriginConfig := &HTTPOriginConfig{}
if err := mapstructure.Decode(originConfig, httpOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HTTPOriginConfig", originConfig)
}
ocu.OriginConfig = httpOriginConfig
return nil
}
if originConfig, ok := originJSON[wsType.String()]; ok {
wsOriginConfig := &WebSocketOriginConfig{}
if err := mapstructure.Decode(originConfig, wsOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into WebSocketOriginConfig", originConfig)
}
ocu.OriginConfig = wsOriginConfig
return nil
}
if originConfig, ok := originJSON[helloWorldType.String()]; ok {
helloWorldOriginConfig := &HelloWorldOriginConfig{}
if err := mapstructure.Decode(originConfig, helloWorldOriginConfig); err != nil {
return errors.Wrapf(err, "cannot decode %+v into HelloWorldOriginConfig", originConfig)
}
ocu.OriginConfig = helloWorldOriginConfig
return nil
}
return fmt.Errorf("cannot unmarshal %s into OriginConfig", string(b))
}

View File

@@ -1,6 +1,13 @@
package pogs
import "testing"
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
type fields struct {
@@ -55,6 +62,180 @@ func TestScopeUnmarshaler_UnmarshalJSON(t *testing.T) {
}
}
func TestUnmarshalOrigin(t *testing.T) {
tests := []struct {
jsonLiteral string
exceptedOriginConfig OriginConfig
}{
{
jsonLiteral: `{
"Http":{
"url_string":"https://127.0.0.1:8080",
"tcp_keep_alive":30000000000,
"dial_dual_stack":true,
"tls_handshake_timeout":10000000000,
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"",
"max_idle_connections":100,
"idle_connection_timeout":90000000000,
"proxy_connection_timeout":90000000000,
"expect_continue_timeout":90000000000,
"chunked_encoding":true
}
}`,
exceptedOriginConfig: &HTTPOriginConfig{
URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30,
DialDualStack: true,
TLSHandshakeTimeout: time.Second * 10,
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "",
MaxIdleConnections: 100,
IdleConnectionTimeout: time.Second * 90,
ProxyConnectionTimeout: time.Second * 90,
ExpectContinueTimeout: time.Second * 90,
ChunkedEncoding: true,
},
},
{
jsonLiteral: `{
"WebSocket":{
"url_string":"https://127.0.0.1:9090",
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"ws.example.com"
}
}`,
exceptedOriginConfig: &WebSocketOriginConfig{
URLString: "https://127.0.0.1:9090",
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "ws.example.com",
},
},
{
jsonLiteral: `{
"HelloWorld": {}
}`,
exceptedOriginConfig: &HelloWorldOriginConfig{},
},
}
for _, test := range tests {
originConfigJSON := strings.ReplaceAll(strings.ReplaceAll(test.jsonLiteral, "\n", ""), "\t", "")
var OriginConfigUnmarshaler OriginConfigUnmarshaler
err := json.Unmarshal([]byte(originConfigJSON), &OriginConfigUnmarshaler)
assert.NoError(t, err)
assert.Equal(t, test.exceptedOriginConfig, OriginConfigUnmarshaler.OriginConfig)
}
}
func TestUnmarshalClientConfig(t *testing.T) {
prettyClientConfigJSON := `{
"version":10,
"supervisor_config":{
"auto_update_frequency":86400000000000,
"metrics_update_frequency":300000000000,
"grace_period":30000000000
},
"edge_connection_config":{
"num_ha_connections":4,
"heartbeat_interval":5000000000,
"timeout":30000000000,
"max_failed_heartbeats":5,
"user_credential_path":"~/.cloudflared/cert.pem"
},
"doh_proxy_configs":[{
"listen_host": "localhost",
"listen_port": 53,
"upstreams": ["https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"]
}],
"reverse_proxy_configs":[{
"tunnel_hostname":"sdfjadk33.cftunnel.com",
"origin_config":{
"Http":{
"url_string":"https://127.0.0.1:8080",
"tcp_keep_alive":30000000000,
"dial_dual_stack":true,
"tls_handshake_timeout":10000000000,
"tls_verify":true,
"origin_ca_pool":"",
"origin_server_name":"",
"max_idle_connections":100,
"idle_connection_timeout":90000000000,
"proxy_connection_timeout":90000000000,
"expect_continue_timeout":90000000000,
"chunked_encoding":true
}
},
"retries":5,
"connection_timeout":30,
"compression_quality":0
}]
}`
// replace new line and tab
clientConfigJSON := strings.ReplaceAll(strings.ReplaceAll(prettyClientConfigJSON, "\n", ""), "\t", "")
var clientConfig ClientConfig
err := json.Unmarshal([]byte(clientConfigJSON), &clientConfig)
assert.NoError(t, err)
assert.Equal(t, Version(10), clientConfig.Version)
supervisorConfig := SupervisorConfig{
AutoUpdateFrequency: time.Hour * 24,
MetricsUpdateFrequency: time.Second * 300,
GracePeriod: time.Second * 30,
}
assert.Equal(t, supervisorConfig, *clientConfig.SupervisorConfig)
edgeConnectionConfig := EdgeConnectionConfig{
NumHAConnections: 4,
HeartbeatInterval: time.Second * 5,
Timeout: time.Second * 30,
MaxFailedHeartbeats: 5,
UserCredentialPath: "~/.cloudflared/cert.pem",
}
assert.Equal(t, edgeConnectionConfig, *clientConfig.EdgeConnectionConfig)
dohProxyConfig := DoHProxyConfig{
ListenHost: "localhost",
ListenPort: 53,
Upstreams: []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"},
}
assert.Len(t, clientConfig.DoHProxyConfigs, 1)
assert.Equal(t, dohProxyConfig, *clientConfig.DoHProxyConfigs[0])
reverseProxyConfig := ReverseProxyConfig{
TunnelHostname: "sdfjadk33.cftunnel.com",
OriginConfigUnmarshaler: &OriginConfigUnmarshaler{
OriginConfig: &HTTPOriginConfig{
URLString: "https://127.0.0.1:8080",
TCPKeepAlive: time.Second * 30,
DialDualStack: true,
TLSHandshakeTimeout: time.Second * 10,
TLSVerify: true,
OriginCAPool: "",
OriginServerName: "",
MaxIdleConnections: 100,
IdleConnectionTimeout: time.Second * 90,
ProxyConnectionTimeout: time.Second * 90,
ExpectContinueTimeout: time.Second * 90,
ChunkedEncoding: true,
},
},
Retries: 5,
ConnectionTimeout: 30,
CompressionQuality: 0,
}
assert.Len(t, clientConfig.ReverseProxyConfigs, 1)
assert.Equal(t, reverseProxyConfig, *clientConfig.ReverseProxyConfigs[0])
}
func eqScope(s1, s2 Scope) bool {
return s1.Value() == s2.Value() && s1.PostgresType() == s2.PostgresType()
}