TUN-4063: Cleanup dependencies between packages.

- Move packages the provide generic functionality (such as config) from `cmd` subtree to top level.
- Remove all dependencies on `cmd` subtree from top level packages.
- Consolidate all code dealing with token generation and transfer to a single cohesive package.
This commit is contained in:
Igor Postelnik
2021-03-08 10:46:23 -06:00
parent d83d6d54ed
commit 39065377b5
47 changed files with 246 additions and 236 deletions

380
config/configuration.go Normal file
View File

@@ -0,0 +1,380 @@
package config
import (
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"runtime"
"time"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"gopkg.in/yaml.v2"
"github.com/cloudflare/cloudflared/validation"
)
var (
// DefaultConfigFiles is the file names from which we attempt to read configuration.
DefaultConfigFiles = []string{"config.yml", "config.yaml"}
// DefaultUnixConfigLocation is the primary location to find a config file
DefaultUnixConfigLocation = "/usr/local/etc/cloudflared"
// DefaultUnixLogLocation is the primary location to find log files
DefaultUnixLogLocation = "/var/log/cloudflared"
// Launchd doesn't set root env variables, so there is default
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
defaultUserConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
defaultNixConfigDirs = []string{"/etc/cloudflared", DefaultUnixConfigLocation}
ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
)
const (
DefaultCredentialFile = "cert.pem"
// BastionFlag is to enable bastion, or jump host, operation
BastionFlag = "bastion"
)
// DefaultConfigDirectory returns the default directory of the config file
func DefaultConfigDirectory() string {
if runtime.GOOS == "windows" {
path := os.Getenv("CFDPATH")
if path == "" {
path = filepath.Join(os.Getenv("ProgramFiles(x86)"), "cloudflared")
if _, err := os.Stat(path); os.IsNotExist(err) { //doesn't exist, so return an empty failure string
return ""
}
}
return path
}
return DefaultUnixConfigLocation
}
// DefaultLogDirectory returns the default directory for log files
func DefaultLogDirectory() string {
if runtime.GOOS == "windows" {
return DefaultConfigDirectory()
}
return DefaultUnixLogLocation
}
// DefaultConfigPath returns the default location of a config file
func DefaultConfigPath() string {
dir := DefaultConfigDirectory()
if dir == "" {
return DefaultConfigFiles[0]
}
return filepath.Join(dir, DefaultConfigFiles[0])
}
// DefaultConfigSearchDirectories returns the default folder locations of the config
func DefaultConfigSearchDirectories() []string {
dirs := make([]string, len(defaultUserConfigDirs))
copy(dirs, defaultUserConfigDirs)
if runtime.GOOS != "windows" {
dirs = append(dirs, defaultNixConfigDirs...)
}
return dirs
}
// FileExists checks to see if a file exist at the provided path.
func FileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
_ = f.Close()
return true, nil
}
// FindDefaultConfigPath returns the first path that contains a config file.
// If none of the combination of DefaultConfigSearchDirectories() and DefaultConfigFiles
// contains a config file, return empty string.
func FindDefaultConfigPath() string {
for _, configDir := range DefaultConfigSearchDirectories() {
for _, configFile := range DefaultConfigFiles {
dirPath, err := homedir.Expand(configDir)
if err != nil {
continue
}
path := filepath.Join(dirPath, configFile)
if ok, _ := FileExists(path); ok {
return path
}
}
}
return ""
}
// FindOrCreateConfigPath returns the first path that contains a config file
// or creates one in the primary default path if it doesn't exist
func FindOrCreateConfigPath() string {
path := FindDefaultConfigPath()
if path == "" {
// create the default directory if it doesn't exist
path = DefaultConfigPath()
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
return ""
}
// write a new config file out
file, err := os.Create(path)
if err != nil {
return ""
}
defer file.Close()
logDir := DefaultLogDirectory()
_ = os.MkdirAll(logDir, os.ModePerm) //try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
c := Root{
LogDirectory: logDir,
}
if err := yaml.NewEncoder(file).Encode(&c); err != nil {
return ""
}
}
return path
}
// ValidateUnixSocket ensures --unix-socket param is used exclusively
// i.e. it fails if a user specifies both --url and --unix-socket
func ValidateUnixSocket(c *cli.Context) (string, error) {
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
return "", errors.New("--unix-socket must be used exclusivly.")
}
return c.String("unix-socket"), nil
}
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) {
var url = c.String("url")
if allowURLFromArgs && c.NArg() > 0 {
if c.IsSet("url") {
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
type UnvalidatedIngressRule struct {
Hostname string
Path string
Service string
OriginRequest OriginRequestConfig `yaml:"originRequest"`
}
// OriginRequestConfig is a set of optional fields that users may set to
// customize how cloudflared sends requests to origin services. It is used to set
// up general config that apply to all rules, and also, specific per-rule
// config.
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
type OriginRequestConfig struct {
// HTTP proxy timeout for establishing a new connection
ConnectTimeout *time.Duration `yaml:"connectTimeout"`
// HTTP proxy timeout for completing a TLS handshake
TLSTimeout *time.Duration `yaml:"tlsTimeout"`
// HTTP proxy TCP keepalive duration
TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"`
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
NoHappyEyeballs *bool `yaml:"noHappyEyeballs"`
// HTTP proxy maximum keepalive connection pool size
KeepAliveConnections *int `yaml:"keepAliveConnections"`
// HTTP proxy timeout for closing an idle connection
KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"`
// Sets the HTTP Host header for the local webserver.
HTTPHostHeader *string `yaml:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName *string `yaml:"originServerName"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool *string `yaml:"caPool"`
// Disables TLS verification of the certificate presented by your origin.
// Will allow any certificate from the origin to be accepted.
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
NoTLSVerify *bool `yaml:"noTLSVerify"`
// Disables chunked transfer encoding.
// Useful if you are running a WSGI server.
DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"`
// Runs as jump host
BastionMode *bool `yaml:"bastionMode"`
// Listen address for the proxy.
ProxyAddress *string `yaml:"proxyAddress"`
// Listen port for the proxy.
ProxyPort *uint `yaml:"proxyPort"`
// Valid options are 'socks' or empty.
ProxyType *string `yaml:"proxyType"`
}
type Configuration struct {
TunnelID string `yaml:"tunnel"`
Ingress []UnvalidatedIngressRule
WarpRouting WarpRoutingConfig `yaml:"warp-routing"`
OriginRequest OriginRequestConfig `yaml:"originRequest"`
sourceFile string
}
type WarpRoutingConfig struct {
Enabled bool `yaml:"enabled"`
}
type configFileSettings struct {
Configuration `yaml:",inline"`
// older settings will be aggregated into the generic map, should be read via cli.Context
Settings map[string]interface{} `yaml:",inline"`
}
func (c *Configuration) Source() string {
return c.sourceFile
}
func (c *configFileSettings) Int(name string) (int, error) {
if raw, ok := c.Settings[name]; ok {
if v, ok := raw.(int); ok {
return v, nil
}
return 0, fmt.Errorf("expected int found %T for %s", raw, name)
}
return 0, nil
}
func (c *configFileSettings) Duration(name string) (time.Duration, error) {
if raw, ok := c.Settings[name]; ok {
switch v := raw.(type) {
case time.Duration:
return v, nil
case string:
return time.ParseDuration(v)
}
return 0, fmt.Errorf("expected duration found %T for %s", raw, name)
}
return 0, nil
}
func (c *configFileSettings) Float64(name string) (float64, error) {
if raw, ok := c.Settings[name]; ok {
if v, ok := raw.(float64); ok {
return v, nil
}
return 0, fmt.Errorf("expected float found %T for %s", raw, name)
}
return 0, nil
}
func (c *configFileSettings) String(name string) (string, error) {
if raw, ok := c.Settings[name]; ok {
if v, ok := raw.(string); ok {
return v, nil
}
return "", fmt.Errorf("expected string found %T for %s", raw, name)
}
return "", nil
}
func (c *configFileSettings) StringSlice(name string) ([]string, error) {
if raw, ok := c.Settings[name]; ok {
if slice, ok := raw.([]interface{}); ok {
strSlice := make([]string, len(slice))
for i, v := range slice {
str, ok := v.(string)
if !ok {
return nil, fmt.Errorf("expected string, found %T for %v", i, v)
}
strSlice[i] = str
}
return strSlice, nil
}
return nil, fmt.Errorf("expected string slice found %T for %s", raw, name)
}
return nil, nil
}
func (c *configFileSettings) IntSlice(name string) ([]int, error) {
if raw, ok := c.Settings[name]; ok {
if slice, ok := raw.([]interface{}); ok {
intSlice := make([]int, len(slice))
for i, v := range slice {
str, ok := v.(int)
if !ok {
return nil, fmt.Errorf("expected int, found %T for %v ", v, v)
}
intSlice[i] = str
}
return intSlice, nil
}
if v, ok := raw.([]int); ok {
return v, nil
}
return nil, fmt.Errorf("expected int slice found %T for %s", raw, name)
}
return nil, nil
}
func (c *configFileSettings) Generic(name string) (cli.Generic, error) {
return nil, errors.New("option type Generic not supported")
}
func (c *configFileSettings) Bool(name string) (bool, error) {
if raw, ok := c.Settings[name]; ok {
if v, ok := raw.(bool); ok {
return v, nil
}
return false, fmt.Errorf("expected boolean found %T for %s", raw, name)
}
return false, nil
}
var configuration configFileSettings
func GetConfiguration() *Configuration {
return &configuration.Configuration
}
// ReadConfigFile returns InputSourceContext initialized from the configuration file.
// On repeat calls returns with the same file, returns without reading the file again; however,
// if value of "config" flag changes, will read the new config file
func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (*configFileSettings, error) {
configFile := c.String("config")
if configuration.Source() == configFile || configFile == "" {
if configuration.Source() == "" {
return nil, ErrNoConfigFile
}
return &configuration, nil
}
log.Debug().Msgf("Loading configuration from %s", configFile)
file, err := os.Open(configFile)
if err != nil {
if os.IsNotExist(err) {
err = ErrNoConfigFile
}
return nil, err
}
defer file.Close()
if err := yaml.NewDecoder(file).Decode(&configuration); err != nil {
if err == io.EOF {
log.Error().Msgf("Configuration file %s was empty", configFile)
return &configuration, nil
}
return nil, errors.Wrap(err, "error parsing YAML in config file at "+configFile)
}
configuration.sourceFile = configFile
return &configuration, nil
}

View File

@@ -0,0 +1,83 @@
package config
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
func TestConfigFileSettings(t *testing.T) {
var (
firstIngress = UnvalidatedIngressRule{
Hostname: "tunnel1.example.com",
Path: "/id",
Service: "https://localhost:8000",
}
secondIngress = UnvalidatedIngressRule{
Hostname: "*",
Path: "",
Service: "https://localhost:8001",
}
warpRouting = WarpRoutingConfig{
Enabled: true,
}
)
rawYAML := `
tunnel: config-file-test
ingress:
- hostname: tunnel1.example.com
path: /id
service: https://localhost:8000
- hostname: "*"
service: https://localhost:8001
warp-routing:
enabled: true
retries: 5
grace-period: 30s
percentage: 3.14
hostname: example.com
tag:
- test
- central-1
counters:
- 123
- 456
`
var config configFileSettings
err := yaml.Unmarshal([]byte(rawYAML), &config)
assert.NoError(t, err)
assert.Equal(t, "config-file-test", config.TunnelID)
assert.Equal(t, firstIngress, config.Ingress[0])
assert.Equal(t, secondIngress, config.Ingress[1])
assert.Equal(t, warpRouting, config.WarpRouting)
retries, err := config.Int("retries")
assert.NoError(t, err)
assert.Equal(t, 5, retries)
gracePeriod, err := config.Duration("grace-period")
assert.NoError(t, err)
assert.Equal(t, time.Second*30, gracePeriod)
percentage, err := config.Float64("percentage")
assert.NoError(t, err)
assert.Equal(t, 3.14, percentage)
hostname, err := config.String("hostname")
assert.NoError(t, err)
assert.Equal(t, "example.com", hostname)
tags, err := config.StringSlice("tag")
assert.NoError(t, err)
assert.Equal(t, "test", tags[0])
assert.Equal(t, "central-1", tags[1])
counters, err := config.IntSlice("counters")
assert.NoError(t, err)
assert.Equal(t, 123, counters[0])
assert.Equal(t, 456, counters[1])
}

112
config/manager.go Normal file
View File

@@ -0,0 +1,112 @@
package config
import (
"io"
"os"
"github.com/cloudflare/cloudflared/watcher"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"gopkg.in/yaml.v2"
)
// Notifier sends out config updates
type Notifier interface {
ConfigDidUpdate(Root)
}
// Manager is the base functions of the config manager
type Manager interface {
Start(Notifier) error
Shutdown()
}
// FileManager watches the yaml config for changes
// sends updates to the service to reconfigure to match the updated config
type FileManager struct {
watcher watcher.Notifier
notifier Notifier
configPath string
log *zerolog.Logger
ReadConfig func(string, *zerolog.Logger) (Root, error)
}
// NewFileManager creates a config manager
func NewFileManager(watcher watcher.Notifier, configPath string, log *zerolog.Logger) (*FileManager, error) {
m := &FileManager{
watcher: watcher,
configPath: configPath,
log: log,
ReadConfig: readConfigFromPath,
}
err := watcher.Add(configPath)
return m, err
}
// Start starts the runloop to watch for config changes
func (m *FileManager) Start(notifier Notifier) error {
m.notifier = notifier
// update the notifier with a fresh config on start
config, err := m.GetConfig()
if err != nil {
return err
}
notifier.ConfigDidUpdate(config)
m.watcher.Start(m)
return nil
}
// GetConfig reads the yaml file from the disk
func (m *FileManager) GetConfig() (Root, error) {
return m.ReadConfig(m.configPath, m.log)
}
// Shutdown stops the watcher
func (m *FileManager) Shutdown() {
m.watcher.Shutdown()
}
func readConfigFromPath(configPath string, log *zerolog.Logger) (Root, error) {
if configPath == "" {
return Root{}, errors.New("unable to find config file")
}
file, err := os.Open(configPath)
if err != nil {
return Root{}, err
}
defer file.Close()
var config Root
if err := yaml.NewDecoder(file).Decode(&config); err != nil {
if err == io.EOF {
log.Error().Msgf("Configuration file %s was empty", configPath)
return Root{}, nil
}
return Root{}, errors.Wrap(err, "error parsing YAML in config file at "+configPath)
}
return config, nil
}
// File change notifications from the watcher
// WatcherItemDidChange triggers when the yaml config is updated
// sends the updated config to the service to reload its state
func (m *FileManager) WatcherItemDidChange(filepath string) {
config, err := m.GetConfig()
if err != nil {
m.log.Err(err).Msg("Failed to read new config")
return
}
m.log.Info().Msg("Config file has been updated")
m.notifier.ConfigDidUpdate(config)
}
// WatcherDidError notifies of errors with the file watcher
func (m *FileManager) WatcherDidError(err error) {
m.log.Err(err).Msg("Config watcher encountered an error")
}

88
config/manager_test.go Normal file
View File

@@ -0,0 +1,88 @@
package config
import (
"os"
"testing"
"github.com/cloudflare/cloudflared/watcher"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)
type mockNotifier struct {
configs []Root
}
func (n *mockNotifier) ConfigDidUpdate(c Root) {
n.configs = append(n.configs, c)
}
type mockFileWatcher struct {
path string
notifier watcher.Notification
ready chan struct{}
}
func (w *mockFileWatcher) Start(n watcher.Notification) {
w.notifier = n
w.ready <- struct{}{}
}
func (w *mockFileWatcher) Add(string) error {
return nil
}
func (w *mockFileWatcher) Shutdown() {
}
func (w *mockFileWatcher) TriggerChange() {
w.notifier.WatcherItemDidChange(w.path)
}
func TestConfigChanged(t *testing.T) {
filePath := "config.yaml"
f, err := os.Create(filePath)
assert.NoError(t, err)
defer func() {
_ = f.Close()
_ = os.Remove(filePath)
}()
c := &Root{
Forwarders: []Forwarder{
{
URL: "test.daltoniam.com",
Listener: "127.0.0.1:8080",
},
},
}
configRead := func(configPath string, log *zerolog.Logger) (Root, error) {
return *c, nil
}
wait := make(chan struct{})
w := &mockFileWatcher{path: filePath, ready: wait}
log := zerolog.Nop()
service, err := NewFileManager(w, filePath, &log)
service.ReadConfig = configRead
assert.NoError(t, err)
n := &mockNotifier{}
go service.Start(n)
<-wait
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
w.TriggerChange()
service.Shutdown()
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")
assert.Len(t, n.configs[0].Forwarders, 1, "not the amount of forwarders expected")
assert.Len(t, n.configs[1].Forwarders, 2, "not the amount of forwarders expected")
assert.Equal(t, n.configs[0].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
assert.Equal(t, n.configs[1].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
assert.Equal(t, n.configs[1].Forwarders[1].Hash(), c.Forwarders[1].Hash(), "forwarder hashes don't match")
}

113
config/model.go Normal file
View File

@@ -0,0 +1,113 @@
package config
import (
"crypto/md5"
"fmt"
"io"
"strings"
"github.com/cloudflare/cloudflared/tunneldns"
)
// Forwarder represents a client side listener to forward traffic to the edge
type Forwarder struct {
URL string `json:"url"`
Listener string `json:"listener"`
TokenClientID string `json:"service_token_id" yaml:"serviceTokenID"`
TokenSecret string `json:"secret_token_id" yaml:"serviceTokenSecret"`
Destination string `json:"destination"`
}
// Tunnel represents a tunnel that should be started
type Tunnel struct {
URL string `json:"url"`
Origin string `json:"origin"`
ProtocolType string `json:"type"`
}
// DNSResolver represents a client side DNS resolver
type DNSResolver struct {
Enabled bool `json:"enabled"`
Address string `json:"address,omitempty"`
Port uint16 `json:"port,omitempty"`
Upstreams []string `json:"upstreams,omitempty"`
Bootstraps []string `json:"bootstraps,omitempty"`
MaxUpstreamConnections int `json:"max_upstream_connections,omitempty"`
}
// Root is the base options to configure the service
type Root struct {
LogDirectory string `json:"log_directory" yaml:"logDirectory,omitempty"`
LogLevel string `json:"log_level" yaml:"logLevel,omitempty"`
Forwarders []Forwarder `json:"forwarders,omitempty" yaml:"forwarders,omitempty"`
Tunnels []Tunnel `json:"tunnels,omitempty" yaml:"tunnels,omitempty"`
Resolver DNSResolver `json:"resolver,omitempty" yaml:"resolver,omitempty"`
}
// Hash returns the computed values to see if the forwarder values change
func (f *Forwarder) Hash() string {
h := md5.New()
io.WriteString(h, f.URL)
io.WriteString(h, f.Listener)
io.WriteString(h, f.TokenClientID)
io.WriteString(h, f.TokenSecret)
io.WriteString(h, f.Destination)
return fmt.Sprintf("%x", h.Sum(nil))
}
// Hash returns the computed values to see if the forwarder values change
func (r *DNSResolver) Hash() string {
h := md5.New()
io.WriteString(h, r.Address)
io.WriteString(h, strings.Join(r.Bootstraps, ","))
io.WriteString(h, strings.Join(r.Upstreams, ","))
io.WriteString(h, fmt.Sprintf("%d", r.Port))
io.WriteString(h, fmt.Sprintf("%d", r.MaxUpstreamConnections))
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
return fmt.Sprintf("%x", h.Sum(nil))
}
// EnabledOrDefault returns the enabled property
func (r *DNSResolver) EnabledOrDefault() bool {
return r.Enabled
}
// AddressOrDefault returns the address or returns the default if empty
func (r *DNSResolver) AddressOrDefault() string {
if r.Address != "" {
return r.Address
}
return "localhost"
}
// PortOrDefault return the port or returns the default if 0
func (r *DNSResolver) PortOrDefault() uint16 {
if r.Port > 0 {
return r.Port
}
return 53
}
// UpstreamsOrDefault returns the upstreams or returns the default if empty
func (r *DNSResolver) UpstreamsOrDefault() []string {
if len(r.Upstreams) > 0 {
return r.Upstreams
}
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
}
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
func (r *DNSResolver) BootstrapsOrDefault() []string {
if len(r.Bootstraps) > 0 {
return r.Bootstraps
}
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
}
// MaxUpstreamConnectionsOrDefault return the max upstream connections or returns the default if negative
func (r *DNSResolver) MaxUpstreamConnectionsOrDefault() int {
if r.MaxUpstreamConnections >= 0 {
return r.MaxUpstreamConnections
}
return tunneldns.MaxUpstreamConnsDefault
}