mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:29:56 +00:00
TUN-5801: Add custom wrapper for OriginConfig for JSON serde
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
@@ -49,7 +51,7 @@ func DefaultConfigDirectory() string {
|
||||
path := os.Getenv("CFDPATH")
|
||||
if path == "" {
|
||||
path = filepath.Join(os.Getenv("ProgramFiles(x86)"), "cloudflared")
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) { //doesn't exist, so return an empty failure string
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) { // doesn't exist, so return an empty failure string
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -138,7 +140,7 @@ func FindOrCreateConfigPath() string {
|
||||
defer file.Close()
|
||||
|
||||
logDir := DefaultLogDirectory()
|
||||
_ = os.MkdirAll(logDir, os.ModePerm) //try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
|
||||
_ = os.MkdirAll(logDir, os.ModePerm) // try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
|
||||
|
||||
c := Root{
|
||||
LogDirectory: logDir,
|
||||
@@ -190,17 +192,17 @@ type UnvalidatedIngressRule struct {
|
||||
// - To specify a time.Duration in json, use int64 of the nanoseconds
|
||||
type OriginRequestConfig struct {
|
||||
// HTTP proxy timeout for establishing a new connection
|
||||
ConnectTimeout *time.Duration `yaml:"connectTimeout" json:"connectTimeout"`
|
||||
ConnectTimeout *CustomDuration `yaml:"connectTimeout" json:"connectTimeout"`
|
||||
// HTTP proxy timeout for completing a TLS handshake
|
||||
TLSTimeout *time.Duration `yaml:"tlsTimeout" json:"tlsTimeout"`
|
||||
TLSTimeout *CustomDuration `yaml:"tlsTimeout" json:"tlsTimeout"`
|
||||
// HTTP proxy TCP keepalive duration
|
||||
TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive" json:"tcpKeepAlive"`
|
||||
TCPKeepAlive *CustomDuration `yaml:"tcpKeepAlive" json:"tcpKeepAlive"`
|
||||
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
|
||||
NoHappyEyeballs *bool `yaml:"noHappyEyeballs" json:"noHappyEyeballs"`
|
||||
// HTTP proxy maximum keepalive connection pool size
|
||||
KeepAliveConnections *int `yaml:"keepAliveConnections" json:"keepAliveConnections"`
|
||||
// HTTP proxy timeout for closing an idle connection
|
||||
KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout" json:"keepAliveTimeout"`
|
||||
KeepAliveTimeout *CustomDuration `yaml:"keepAliveTimeout" json:"keepAliveTimeout"`
|
||||
// Sets the HTTP Host header for the local webserver.
|
||||
HTTPHostHeader *string `yaml:"httpHostHeader" json:"httpHostHeader"`
|
||||
// Hostname on the origin server certificate.
|
||||
@@ -399,3 +401,34 @@ func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (settings *configFileSe
|
||||
|
||||
return &configuration, warnings, nil
|
||||
}
|
||||
|
||||
// A CustomDuration is a Duration that has custom serialization for JSON.
|
||||
// JSON in Javascript assumes that int fields are 32 bits and Duration fields are deserialized assuming that numbers
|
||||
// are in nanoseconds, which in 32bit integers limits to just 2 seconds.
|
||||
// This type assumes that when serializing/deserializing from JSON, that the number is in seconds, while it maintains
|
||||
// the YAML serde assumptions.
|
||||
type CustomDuration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func (s *CustomDuration) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(s.Duration.Seconds())
|
||||
}
|
||||
|
||||
func (s *CustomDuration) UnmarshalJSON(data []byte) error {
|
||||
seconds, err := strconv.ParseInt(string(data), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.Duration = time.Duration(seconds * int64(time.Second))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *CustomDuration) MarshalYAML() (interface{}, error) {
|
||||
return s.Duration.String(), nil
|
||||
}
|
||||
|
||||
func (s *CustomDuration) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return unmarshal(&s.Duration)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
@@ -110,14 +111,13 @@ counters:
|
||||
|
||||
}
|
||||
|
||||
func TestUnmarshalOriginRequestConfig(t *testing.T) {
|
||||
raw := []byte(`
|
||||
var rawConfig = []byte(`
|
||||
{
|
||||
"connectTimeout": 10000000000,
|
||||
"tlsTimeout": 30000000000,
|
||||
"tcpKeepAlive": 30000000000,
|
||||
"connectTimeout": 10,
|
||||
"tlsTimeout": 30,
|
||||
"tcpKeepAlive": 30,
|
||||
"noHappyEyeballs": true,
|
||||
"keepAliveTimeout": 60000000000,
|
||||
"keepAliveTimeout": 60,
|
||||
"keepAliveConnections": 10,
|
||||
"httpHostHeader": "app.tunnel.com",
|
||||
"originServerName": "app.tunnel.com",
|
||||
@@ -142,13 +142,41 @@ func TestUnmarshalOriginRequestConfig(t *testing.T) {
|
||||
]
|
||||
}
|
||||
`)
|
||||
|
||||
func TestMarshalUnmarshalOriginRequest(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
marshalFunc func(in interface{}) (out []byte, err error)
|
||||
unMarshalFunc func(in []byte, out interface{}) (err error)
|
||||
baseUnit time.Duration
|
||||
}{
|
||||
{"json", json.Marshal, json.Unmarshal, time.Second},
|
||||
{"yaml", yaml.Marshal, yaml.Unmarshal, time.Nanosecond},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assertConfig(t, tc.marshalFunc, tc.unMarshalFunc, tc.baseUnit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfig(
|
||||
t *testing.T,
|
||||
marshalFunc func(in interface{}) (out []byte, err error),
|
||||
unMarshalFunc func(in []byte, out interface{}) (err error),
|
||||
baseUnit time.Duration,
|
||||
) {
|
||||
var config OriginRequestConfig
|
||||
assert.NoError(t, json.Unmarshal(raw, &config))
|
||||
assert.Equal(t, time.Second*10, *config.ConnectTimeout)
|
||||
assert.Equal(t, time.Second*30, *config.TLSTimeout)
|
||||
assert.Equal(t, time.Second*30, *config.TCPKeepAlive)
|
||||
var config2 OriginRequestConfig
|
||||
|
||||
assert.NoError(t, unMarshalFunc(rawConfig, &config))
|
||||
|
||||
assert.Equal(t, baseUnit*10, config.ConnectTimeout.Duration)
|
||||
assert.Equal(t, baseUnit*30, config.TLSTimeout.Duration)
|
||||
assert.Equal(t, baseUnit*30, config.TCPKeepAlive.Duration)
|
||||
assert.Equal(t, true, *config.NoHappyEyeballs)
|
||||
assert.Equal(t, time.Second*60, *config.KeepAliveTimeout)
|
||||
assert.Equal(t, baseUnit*60, config.KeepAliveTimeout.Duration)
|
||||
assert.Equal(t, 10, *config.KeepAliveConnections)
|
||||
assert.Equal(t, "app.tunnel.com", *config.HTTPHostHeader)
|
||||
assert.Equal(t, "app.tunnel.com", *config.OriginServerName)
|
||||
@@ -176,4 +204,12 @@ func TestUnmarshalOriginRequestConfig(t *testing.T) {
|
||||
},
|
||||
}
|
||||
assert.Equal(t, ipRules, config.IPRules)
|
||||
|
||||
// validate that serializing and deserializing again matches the deserialization from raw string
|
||||
result, err := marshalFunc(config)
|
||||
require.NoError(t, err)
|
||||
err = unMarshalFunc(result, &config2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, config2, config)
|
||||
}
|
||||
|
Reference in New Issue
Block a user