mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:59:57 +00:00
TUN-4063: Cleanup dependencies between packages.
- Move packages the provide generic functionality (such as config) from `cmd` subtree to top level. - Remove all dependencies on `cmd` subtree from top level packages. - Consolidate all code dealing with token generation and transfer to a single cohesive package.
This commit is contained in:
380
config/configuration.go
Normal file
380
config/configuration.go
Normal file
@@ -0,0 +1,380 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfigFiles is the file names from which we attempt to read configuration.
|
||||
DefaultConfigFiles = []string{"config.yml", "config.yaml"}
|
||||
|
||||
// DefaultUnixConfigLocation is the primary location to find a config file
|
||||
DefaultUnixConfigLocation = "/usr/local/etc/cloudflared"
|
||||
|
||||
// DefaultUnixLogLocation is the primary location to find log files
|
||||
DefaultUnixLogLocation = "/var/log/cloudflared"
|
||||
|
||||
// Launchd doesn't set root env variables, so there is default
|
||||
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
|
||||
defaultUserConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
|
||||
defaultNixConfigDirs = []string{"/etc/cloudflared", DefaultUnixConfigLocation}
|
||||
|
||||
ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCredentialFile = "cert.pem"
|
||||
|
||||
// BastionFlag is to enable bastion, or jump host, operation
|
||||
BastionFlag = "bastion"
|
||||
)
|
||||
|
||||
// DefaultConfigDirectory returns the default directory of the config file
|
||||
func DefaultConfigDirectory() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
path := os.Getenv("CFDPATH")
|
||||
if path == "" {
|
||||
path = filepath.Join(os.Getenv("ProgramFiles(x86)"), "cloudflared")
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) { //doesn't exist, so return an empty failure string
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
return DefaultUnixConfigLocation
|
||||
}
|
||||
|
||||
// DefaultLogDirectory returns the default directory for log files
|
||||
func DefaultLogDirectory() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultConfigDirectory()
|
||||
}
|
||||
return DefaultUnixLogLocation
|
||||
}
|
||||
|
||||
// DefaultConfigPath returns the default location of a config file
|
||||
func DefaultConfigPath() string {
|
||||
dir := DefaultConfigDirectory()
|
||||
if dir == "" {
|
||||
return DefaultConfigFiles[0]
|
||||
}
|
||||
return filepath.Join(dir, DefaultConfigFiles[0])
|
||||
}
|
||||
|
||||
// DefaultConfigSearchDirectories returns the default folder locations of the config
|
||||
func DefaultConfigSearchDirectories() []string {
|
||||
dirs := make([]string, len(defaultUserConfigDirs))
|
||||
copy(dirs, defaultUserConfigDirs)
|
||||
if runtime.GOOS != "windows" {
|
||||
dirs = append(dirs, defaultNixConfigDirs...)
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
// FileExists checks to see if a file exist at the provided path.
|
||||
func FileExists(path string) (bool, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// ignore missing files
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
_ = f.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FindDefaultConfigPath returns the first path that contains a config file.
|
||||
// If none of the combination of DefaultConfigSearchDirectories() and DefaultConfigFiles
|
||||
// contains a config file, return empty string.
|
||||
func FindDefaultConfigPath() string {
|
||||
for _, configDir := range DefaultConfigSearchDirectories() {
|
||||
for _, configFile := range DefaultConfigFiles {
|
||||
dirPath, err := homedir.Expand(configDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dirPath, configFile)
|
||||
if ok, _ := FileExists(path); ok {
|
||||
return path
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FindOrCreateConfigPath returns the first path that contains a config file
|
||||
// or creates one in the primary default path if it doesn't exist
|
||||
func FindOrCreateConfigPath() string {
|
||||
path := FindDefaultConfigPath()
|
||||
|
||||
if path == "" {
|
||||
// create the default directory if it doesn't exist
|
||||
path = DefaultConfigPath()
|
||||
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// write a new config file out
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
logDir := DefaultLogDirectory()
|
||||
_ = os.MkdirAll(logDir, os.ModePerm) //try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
|
||||
|
||||
c := Root{
|
||||
LogDirectory: logDir,
|
||||
}
|
||||
if err := yaml.NewEncoder(file).Encode(&c); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// ValidateUnixSocket ensures --unix-socket param is used exclusively
|
||||
// i.e. it fails if a user specifies both --url and --unix-socket
|
||||
func ValidateUnixSocket(c *cli.Context) (string, error) {
|
||||
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
|
||||
return "", errors.New("--unix-socket must be used exclusivly.")
|
||||
}
|
||||
return c.String("unix-socket"), nil
|
||||
}
|
||||
|
||||
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
|
||||
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
|
||||
func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) {
|
||||
var url = c.String("url")
|
||||
if allowURLFromArgs && c.NArg() > 0 {
|
||||
if c.IsSet("url") {
|
||||
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||
}
|
||||
url = c.Args().Get(0)
|
||||
}
|
||||
validUrl, err := validation.ValidateUrl(url)
|
||||
return validUrl, err
|
||||
}
|
||||
|
||||
type UnvalidatedIngressRule struct {
|
||||
Hostname string
|
||||
Path string
|
||||
Service string
|
||||
OriginRequest OriginRequestConfig `yaml:"originRequest"`
|
||||
}
|
||||
|
||||
// OriginRequestConfig is a set of optional fields that users may set to
|
||||
// customize how cloudflared sends requests to origin services. It is used to set
|
||||
// up general config that apply to all rules, and also, specific per-rule
|
||||
// config.
|
||||
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
|
||||
type OriginRequestConfig struct {
|
||||
// HTTP proxy timeout for establishing a new connection
|
||||
ConnectTimeout *time.Duration `yaml:"connectTimeout"`
|
||||
// HTTP proxy timeout for completing a TLS handshake
|
||||
TLSTimeout *time.Duration `yaml:"tlsTimeout"`
|
||||
// HTTP proxy TCP keepalive duration
|
||||
TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"`
|
||||
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
|
||||
NoHappyEyeballs *bool `yaml:"noHappyEyeballs"`
|
||||
// HTTP proxy maximum keepalive connection pool size
|
||||
KeepAliveConnections *int `yaml:"keepAliveConnections"`
|
||||
// HTTP proxy timeout for closing an idle connection
|
||||
KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"`
|
||||
// Sets the HTTP Host header for the local webserver.
|
||||
HTTPHostHeader *string `yaml:"httpHostHeader"`
|
||||
// Hostname on the origin server certificate.
|
||||
OriginServerName *string `yaml:"originServerName"`
|
||||
// Path to the CA for the certificate of your origin.
|
||||
// This option should be used only if your certificate is not signed by Cloudflare.
|
||||
CAPool *string `yaml:"caPool"`
|
||||
// Disables TLS verification of the certificate presented by your origin.
|
||||
// Will allow any certificate from the origin to be accepted.
|
||||
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
|
||||
NoTLSVerify *bool `yaml:"noTLSVerify"`
|
||||
// Disables chunked transfer encoding.
|
||||
// Useful if you are running a WSGI server.
|
||||
DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"`
|
||||
// Runs as jump host
|
||||
BastionMode *bool `yaml:"bastionMode"`
|
||||
// Listen address for the proxy.
|
||||
ProxyAddress *string `yaml:"proxyAddress"`
|
||||
// Listen port for the proxy.
|
||||
ProxyPort *uint `yaml:"proxyPort"`
|
||||
// Valid options are 'socks' or empty.
|
||||
ProxyType *string `yaml:"proxyType"`
|
||||
}
|
||||
|
||||
type Configuration struct {
|
||||
TunnelID string `yaml:"tunnel"`
|
||||
Ingress []UnvalidatedIngressRule
|
||||
WarpRouting WarpRoutingConfig `yaml:"warp-routing"`
|
||||
OriginRequest OriginRequestConfig `yaml:"originRequest"`
|
||||
sourceFile string
|
||||
}
|
||||
|
||||
type WarpRoutingConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
type configFileSettings struct {
|
||||
Configuration `yaml:",inline"`
|
||||
// older settings will be aggregated into the generic map, should be read via cli.Context
|
||||
Settings map[string]interface{} `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (c *Configuration) Source() string {
|
||||
return c.sourceFile
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Int(name string) (int, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(int); ok {
|
||||
return v, nil
|
||||
}
|
||||
return 0, fmt.Errorf("expected int found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Duration(name string) (time.Duration, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
switch v := raw.(type) {
|
||||
case time.Duration:
|
||||
return v, nil
|
||||
case string:
|
||||
return time.ParseDuration(v)
|
||||
}
|
||||
return 0, fmt.Errorf("expected duration found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Float64(name string) (float64, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(float64); ok {
|
||||
return v, nil
|
||||
}
|
||||
return 0, fmt.Errorf("expected float found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) String(name string) (string, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(string); ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", fmt.Errorf("expected string found %T for %s", raw, name)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) StringSlice(name string) ([]string, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if slice, ok := raw.([]interface{}); ok {
|
||||
strSlice := make([]string, len(slice))
|
||||
for i, v := range slice {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected string, found %T for %v", i, v)
|
||||
}
|
||||
strSlice[i] = str
|
||||
}
|
||||
return strSlice, nil
|
||||
}
|
||||
return nil, fmt.Errorf("expected string slice found %T for %s", raw, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) IntSlice(name string) ([]int, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if slice, ok := raw.([]interface{}); ok {
|
||||
intSlice := make([]int, len(slice))
|
||||
for i, v := range slice {
|
||||
str, ok := v.(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected int, found %T for %v ", v, v)
|
||||
}
|
||||
intSlice[i] = str
|
||||
}
|
||||
return intSlice, nil
|
||||
}
|
||||
if v, ok := raw.([]int); ok {
|
||||
return v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("expected int slice found %T for %s", raw, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Generic(name string) (cli.Generic, error) {
|
||||
return nil, errors.New("option type Generic not supported")
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Bool(name string) (bool, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(bool); ok {
|
||||
return v, nil
|
||||
}
|
||||
return false, fmt.Errorf("expected boolean found %T for %s", raw, name)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var configuration configFileSettings
|
||||
|
||||
func GetConfiguration() *Configuration {
|
||||
return &configuration.Configuration
|
||||
}
|
||||
|
||||
// ReadConfigFile returns InputSourceContext initialized from the configuration file.
|
||||
// On repeat calls returns with the same file, returns without reading the file again; however,
|
||||
// if value of "config" flag changes, will read the new config file
|
||||
func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (*configFileSettings, error) {
|
||||
configFile := c.String("config")
|
||||
if configuration.Source() == configFile || configFile == "" {
|
||||
if configuration.Source() == "" {
|
||||
return nil, ErrNoConfigFile
|
||||
}
|
||||
return &configuration, nil
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Loading configuration from %s", configFile)
|
||||
file, err := os.Open(configFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
err = ErrNoConfigFile
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
if err := yaml.NewDecoder(file).Decode(&configuration); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Error().Msgf("Configuration file %s was empty", configFile)
|
||||
return &configuration, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "error parsing YAML in config file at "+configFile)
|
||||
}
|
||||
configuration.sourceFile = configFile
|
||||
return &configuration, nil
|
||||
}
|
83
config/configuration_test.go
Normal file
83
config/configuration_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func TestConfigFileSettings(t *testing.T) {
|
||||
var (
|
||||
firstIngress = UnvalidatedIngressRule{
|
||||
Hostname: "tunnel1.example.com",
|
||||
Path: "/id",
|
||||
Service: "https://localhost:8000",
|
||||
}
|
||||
secondIngress = UnvalidatedIngressRule{
|
||||
Hostname: "*",
|
||||
Path: "",
|
||||
Service: "https://localhost:8001",
|
||||
}
|
||||
warpRouting = WarpRoutingConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
)
|
||||
rawYAML := `
|
||||
tunnel: config-file-test
|
||||
ingress:
|
||||
- hostname: tunnel1.example.com
|
||||
path: /id
|
||||
service: https://localhost:8000
|
||||
- hostname: "*"
|
||||
service: https://localhost:8001
|
||||
warp-routing:
|
||||
enabled: true
|
||||
retries: 5
|
||||
grace-period: 30s
|
||||
percentage: 3.14
|
||||
hostname: example.com
|
||||
tag:
|
||||
- test
|
||||
- central-1
|
||||
counters:
|
||||
- 123
|
||||
- 456
|
||||
`
|
||||
var config configFileSettings
|
||||
err := yaml.Unmarshal([]byte(rawYAML), &config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "config-file-test", config.TunnelID)
|
||||
assert.Equal(t, firstIngress, config.Ingress[0])
|
||||
assert.Equal(t, secondIngress, config.Ingress[1])
|
||||
assert.Equal(t, warpRouting, config.WarpRouting)
|
||||
|
||||
retries, err := config.Int("retries")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, retries)
|
||||
|
||||
gracePeriod, err := config.Duration("grace-period")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, time.Second*30, gracePeriod)
|
||||
|
||||
percentage, err := config.Float64("percentage")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3.14, percentage)
|
||||
|
||||
hostname, err := config.String("hostname")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "example.com", hostname)
|
||||
|
||||
tags, err := config.StringSlice("tag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", tags[0])
|
||||
assert.Equal(t, "central-1", tags[1])
|
||||
|
||||
counters, err := config.IntSlice("counters")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 123, counters[0])
|
||||
assert.Equal(t, 456, counters[1])
|
||||
|
||||
}
|
112
config/manager.go
Normal file
112
config/manager.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Notifier sends out config updates
|
||||
type Notifier interface {
|
||||
ConfigDidUpdate(Root)
|
||||
}
|
||||
|
||||
// Manager is the base functions of the config manager
|
||||
type Manager interface {
|
||||
Start(Notifier) error
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// FileManager watches the yaml config for changes
|
||||
// sends updates to the service to reconfigure to match the updated config
|
||||
type FileManager struct {
|
||||
watcher watcher.Notifier
|
||||
notifier Notifier
|
||||
configPath string
|
||||
log *zerolog.Logger
|
||||
ReadConfig func(string, *zerolog.Logger) (Root, error)
|
||||
}
|
||||
|
||||
// NewFileManager creates a config manager
|
||||
func NewFileManager(watcher watcher.Notifier, configPath string, log *zerolog.Logger) (*FileManager, error) {
|
||||
m := &FileManager{
|
||||
watcher: watcher,
|
||||
configPath: configPath,
|
||||
log: log,
|
||||
ReadConfig: readConfigFromPath,
|
||||
}
|
||||
err := watcher.Add(configPath)
|
||||
return m, err
|
||||
}
|
||||
|
||||
// Start starts the runloop to watch for config changes
|
||||
func (m *FileManager) Start(notifier Notifier) error {
|
||||
m.notifier = notifier
|
||||
|
||||
// update the notifier with a fresh config on start
|
||||
config, err := m.GetConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
notifier.ConfigDidUpdate(config)
|
||||
|
||||
m.watcher.Start(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig reads the yaml file from the disk
|
||||
func (m *FileManager) GetConfig() (Root, error) {
|
||||
return m.ReadConfig(m.configPath, m.log)
|
||||
}
|
||||
|
||||
// Shutdown stops the watcher
|
||||
func (m *FileManager) Shutdown() {
|
||||
m.watcher.Shutdown()
|
||||
}
|
||||
|
||||
func readConfigFromPath(configPath string, log *zerolog.Logger) (Root, error) {
|
||||
if configPath == "" {
|
||||
return Root{}, errors.New("unable to find config file")
|
||||
}
|
||||
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
return Root{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var config Root
|
||||
if err := yaml.NewDecoder(file).Decode(&config); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Error().Msgf("Configuration file %s was empty", configPath)
|
||||
return Root{}, nil
|
||||
}
|
||||
return Root{}, errors.Wrap(err, "error parsing YAML in config file at "+configPath)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// File change notifications from the watcher
|
||||
|
||||
// WatcherItemDidChange triggers when the yaml config is updated
|
||||
// sends the updated config to the service to reload its state
|
||||
func (m *FileManager) WatcherItemDidChange(filepath string) {
|
||||
config, err := m.GetConfig()
|
||||
if err != nil {
|
||||
m.log.Err(err).Msg("Failed to read new config")
|
||||
return
|
||||
}
|
||||
m.log.Info().Msg("Config file has been updated")
|
||||
m.notifier.ConfigDidUpdate(config)
|
||||
}
|
||||
|
||||
// WatcherDidError notifies of errors with the file watcher
|
||||
func (m *FileManager) WatcherDidError(err error) {
|
||||
m.log.Err(err).Msg("Config watcher encountered an error")
|
||||
}
|
88
config/manager_test.go
Normal file
88
config/manager_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockNotifier struct {
|
||||
configs []Root
|
||||
}
|
||||
|
||||
func (n *mockNotifier) ConfigDidUpdate(c Root) {
|
||||
n.configs = append(n.configs, c)
|
||||
}
|
||||
|
||||
type mockFileWatcher struct {
|
||||
path string
|
||||
notifier watcher.Notification
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Start(n watcher.Notification) {
|
||||
w.notifier = n
|
||||
w.ready <- struct{}{}
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Add(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Shutdown() {
|
||||
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) TriggerChange() {
|
||||
w.notifier.WatcherItemDidChange(w.path)
|
||||
}
|
||||
|
||||
func TestConfigChanged(t *testing.T) {
|
||||
filePath := "config.yaml"
|
||||
f, err := os.Create(filePath)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(filePath)
|
||||
}()
|
||||
c := &Root{
|
||||
Forwarders: []Forwarder{
|
||||
{
|
||||
URL: "test.daltoniam.com",
|
||||
Listener: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
}
|
||||
configRead := func(configPath string, log *zerolog.Logger) (Root, error) {
|
||||
return *c, nil
|
||||
}
|
||||
wait := make(chan struct{})
|
||||
w := &mockFileWatcher{path: filePath, ready: wait}
|
||||
|
||||
log := zerolog.Nop()
|
||||
|
||||
service, err := NewFileManager(w, filePath, &log)
|
||||
service.ReadConfig = configRead
|
||||
assert.NoError(t, err)
|
||||
|
||||
n := &mockNotifier{}
|
||||
go service.Start(n)
|
||||
|
||||
<-wait
|
||||
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
|
||||
w.TriggerChange()
|
||||
|
||||
service.Shutdown()
|
||||
|
||||
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")
|
||||
assert.Len(t, n.configs[0].Forwarders, 1, "not the amount of forwarders expected")
|
||||
assert.Len(t, n.configs[1].Forwarders, 2, "not the amount of forwarders expected")
|
||||
|
||||
assert.Equal(t, n.configs[0].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
|
||||
assert.Equal(t, n.configs[1].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
|
||||
assert.Equal(t, n.configs[1].Forwarders[1].Hash(), c.Forwarders[1].Hash(), "forwarder hashes don't match")
|
||||
}
|
113
config/model.go
Normal file
113
config/model.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
)
|
||||
|
||||
// Forwarder represents a client side listener to forward traffic to the edge
|
||||
type Forwarder struct {
|
||||
URL string `json:"url"`
|
||||
Listener string `json:"listener"`
|
||||
TokenClientID string `json:"service_token_id" yaml:"serviceTokenID"`
|
||||
TokenSecret string `json:"secret_token_id" yaml:"serviceTokenSecret"`
|
||||
Destination string `json:"destination"`
|
||||
}
|
||||
|
||||
// Tunnel represents a tunnel that should be started
|
||||
type Tunnel struct {
|
||||
URL string `json:"url"`
|
||||
Origin string `json:"origin"`
|
||||
ProtocolType string `json:"type"`
|
||||
}
|
||||
|
||||
// DNSResolver represents a client side DNS resolver
|
||||
type DNSResolver struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Address string `json:"address,omitempty"`
|
||||
Port uint16 `json:"port,omitempty"`
|
||||
Upstreams []string `json:"upstreams,omitempty"`
|
||||
Bootstraps []string `json:"bootstraps,omitempty"`
|
||||
MaxUpstreamConnections int `json:"max_upstream_connections,omitempty"`
|
||||
}
|
||||
|
||||
// Root is the base options to configure the service
|
||||
type Root struct {
|
||||
LogDirectory string `json:"log_directory" yaml:"logDirectory,omitempty"`
|
||||
LogLevel string `json:"log_level" yaml:"logLevel,omitempty"`
|
||||
Forwarders []Forwarder `json:"forwarders,omitempty" yaml:"forwarders,omitempty"`
|
||||
Tunnels []Tunnel `json:"tunnels,omitempty" yaml:"tunnels,omitempty"`
|
||||
Resolver DNSResolver `json:"resolver,omitempty" yaml:"resolver,omitempty"`
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
func (f *Forwarder) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, f.URL)
|
||||
io.WriteString(h, f.Listener)
|
||||
io.WriteString(h, f.TokenClientID)
|
||||
io.WriteString(h, f.TokenSecret)
|
||||
io.WriteString(h, f.Destination)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
func (r *DNSResolver) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, r.Address)
|
||||
io.WriteString(h, strings.Join(r.Bootstraps, ","))
|
||||
io.WriteString(h, strings.Join(r.Upstreams, ","))
|
||||
io.WriteString(h, fmt.Sprintf("%d", r.Port))
|
||||
io.WriteString(h, fmt.Sprintf("%d", r.MaxUpstreamConnections))
|
||||
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// EnabledOrDefault returns the enabled property
|
||||
func (r *DNSResolver) EnabledOrDefault() bool {
|
||||
return r.Enabled
|
||||
}
|
||||
|
||||
// AddressOrDefault returns the address or returns the default if empty
|
||||
func (r *DNSResolver) AddressOrDefault() string {
|
||||
if r.Address != "" {
|
||||
return r.Address
|
||||
}
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
// PortOrDefault return the port or returns the default if 0
|
||||
func (r *DNSResolver) PortOrDefault() uint16 {
|
||||
if r.Port > 0 {
|
||||
return r.Port
|
||||
}
|
||||
return 53
|
||||
}
|
||||
|
||||
// UpstreamsOrDefault returns the upstreams or returns the default if empty
|
||||
func (r *DNSResolver) UpstreamsOrDefault() []string {
|
||||
if len(r.Upstreams) > 0 {
|
||||
return r.Upstreams
|
||||
}
|
||||
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
|
||||
}
|
||||
|
||||
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
|
||||
func (r *DNSResolver) BootstrapsOrDefault() []string {
|
||||
if len(r.Bootstraps) > 0 {
|
||||
return r.Bootstraps
|
||||
}
|
||||
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
|
||||
}
|
||||
|
||||
// MaxUpstreamConnectionsOrDefault return the max upstream connections or returns the default if negative
|
||||
func (r *DNSResolver) MaxUpstreamConnectionsOrDefault() int {
|
||||
if r.MaxUpstreamConnections >= 0 {
|
||||
return r.MaxUpstreamConnections
|
||||
}
|
||||
return tunneldns.MaxUpstreamConnsDefault
|
||||
}
|
Reference in New Issue
Block a user