mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 15:30:04 +00:00
TUN-4063: Cleanup dependencies between packages.
- Move packages the provide generic functionality (such as config) from `cmd` subtree to top level. - Remove all dependencies on `cmd` subtree from top level packages. - Consolidate all code dealing with token generation and transfer to a single cohesive package.
This commit is contained in:
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
|
@@ -2,20 +2,21 @@ package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/token"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/sshgen"
|
||||
"github.com/cloudflare/cloudflared/token"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
@@ -271,7 +272,7 @@ func curl(c *cli.Context) error {
|
||||
if err != nil || tok == "" {
|
||||
if allowRequest {
|
||||
log.Info().Msg("You don't have an Access token set. Please run access token <access application> to fetch one.")
|
||||
return shell.Run("curl", cmdArgs...)
|
||||
return run("curl", cmdArgs...)
|
||||
}
|
||||
tok, err = token.FetchToken(appURL, log)
|
||||
if err != nil {
|
||||
@@ -282,7 +283,29 @@ func curl(c *cli.Context) error {
|
||||
|
||||
cmdArgs = append(cmdArgs, "-H")
|
||||
cmdArgs = append(cmdArgs, fmt.Sprintf("%s: %s", h2mux.CFAccessTokenHeader, tok))
|
||||
return shell.Run("curl", cmdArgs...)
|
||||
return run("curl", cmdArgs...)
|
||||
}
|
||||
|
||||
|
||||
// run kicks off a shell task and pipe the results to the respective std pipes
|
||||
func run(cmd string, args ...string) error {
|
||||
c := exec.Command(cmd, args...)
|
||||
stderr, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stderr)
|
||||
}()
|
||||
|
||||
stdout, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stdout, stdout)
|
||||
}()
|
||||
return c.Run()
|
||||
}
|
||||
|
||||
// token dumps provided token to stdout
|
||||
|
@@ -2,7 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/overwatch"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@@ -2,6 +2,7 @@ package buildinfo
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
@@ -25,3 +26,7 @@ func (bi *BuildInfo) Log(log *zerolog.Logger) {
|
||||
log.Info().Msgf("Version %s", bi.CloudflaredVersion)
|
||||
log.Info().Msgf("GOOS: %s, GOVersion: %s, GoArch: %s", bi.GoOS, bi.GoVersion, bi.GoArch)
|
||||
}
|
||||
|
||||
func (bi *BuildInfo) OSArch() string {
|
||||
return fmt.Sprintf("%s_%s", bi.GoOS, bi.GoArch)
|
||||
}
|
||||
|
@@ -1,380 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/cli/v2"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var (
|
||||
// DefaultConfigFiles is the file names from which we attempt to read configuration.
|
||||
DefaultConfigFiles = []string{"config.yml", "config.yaml"}
|
||||
|
||||
// DefaultUnixConfigLocation is the primary location to find a config file
|
||||
DefaultUnixConfigLocation = "/usr/local/etc/cloudflared"
|
||||
|
||||
// DefaultUnixLogLocation is the primary location to find log files
|
||||
DefaultUnixLogLocation = "/var/log/cloudflared"
|
||||
|
||||
// Launchd doesn't set root env variables, so there is default
|
||||
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
|
||||
defaultUserConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp"}
|
||||
defaultNixConfigDirs = []string{"/etc/cloudflared", DefaultUnixConfigLocation}
|
||||
|
||||
ErrNoConfigFile = fmt.Errorf("Cannot determine default configuration path. No file %v in %v", DefaultConfigFiles, DefaultConfigSearchDirectories())
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCredentialFile = "cert.pem"
|
||||
|
||||
// BastionFlag is to enable bastion, or jump host, operation
|
||||
BastionFlag = "bastion"
|
||||
)
|
||||
|
||||
// DefaultConfigDirectory returns the default directory of the config file
|
||||
func DefaultConfigDirectory() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
path := os.Getenv("CFDPATH")
|
||||
if path == "" {
|
||||
path = filepath.Join(os.Getenv("ProgramFiles(x86)"), "cloudflared")
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) { //doesn't exist, so return an empty failure string
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return path
|
||||
}
|
||||
return DefaultUnixConfigLocation
|
||||
}
|
||||
|
||||
// DefaultLogDirectory returns the default directory for log files
|
||||
func DefaultLogDirectory() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultConfigDirectory()
|
||||
}
|
||||
return DefaultUnixLogLocation
|
||||
}
|
||||
|
||||
// DefaultConfigPath returns the default location of a config file
|
||||
func DefaultConfigPath() string {
|
||||
dir := DefaultConfigDirectory()
|
||||
if dir == "" {
|
||||
return DefaultConfigFiles[0]
|
||||
}
|
||||
return filepath.Join(dir, DefaultConfigFiles[0])
|
||||
}
|
||||
|
||||
// DefaultConfigSearchDirectories returns the default folder locations of the config
|
||||
func DefaultConfigSearchDirectories() []string {
|
||||
dirs := make([]string, len(defaultUserConfigDirs))
|
||||
copy(dirs, defaultUserConfigDirs)
|
||||
if runtime.GOOS != "windows" {
|
||||
dirs = append(dirs, defaultNixConfigDirs...)
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
// FileExists checks to see if a file exist at the provided path.
|
||||
func FileExists(path string) (bool, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// ignore missing files
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
_ = f.Close()
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// FindDefaultConfigPath returns the first path that contains a config file.
|
||||
// If none of the combination of DefaultConfigSearchDirectories() and DefaultConfigFiles
|
||||
// contains a config file, return empty string.
|
||||
func FindDefaultConfigPath() string {
|
||||
for _, configDir := range DefaultConfigSearchDirectories() {
|
||||
for _, configFile := range DefaultConfigFiles {
|
||||
dirPath, err := homedir.Expand(configDir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dirPath, configFile)
|
||||
if ok, _ := FileExists(path); ok {
|
||||
return path
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// FindOrCreateConfigPath returns the first path that contains a config file
|
||||
// or creates one in the primary default path if it doesn't exist
|
||||
func FindOrCreateConfigPath() string {
|
||||
path := FindDefaultConfigPath()
|
||||
|
||||
if path == "" {
|
||||
// create the default directory if it doesn't exist
|
||||
path = DefaultConfigPath()
|
||||
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// write a new config file out
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
logDir := DefaultLogDirectory()
|
||||
_ = os.MkdirAll(logDir, os.ModePerm) //try and create it. Doesn't matter if it succeed or not, only byproduct will be no logs
|
||||
|
||||
c := Root{
|
||||
LogDirectory: logDir,
|
||||
}
|
||||
if err := yaml.NewEncoder(file).Encode(&c); err != nil {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// ValidateUnixSocket ensures --unix-socket param is used exclusively
|
||||
// i.e. it fails if a user specifies both --url and --unix-socket
|
||||
func ValidateUnixSocket(c *cli.Context) (string, error) {
|
||||
if c.IsSet("unix-socket") && (c.IsSet("url") || c.NArg() > 0) {
|
||||
return "", errors.New("--unix-socket must be used exclusivly.")
|
||||
}
|
||||
return c.String("unix-socket"), nil
|
||||
}
|
||||
|
||||
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
|
||||
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
|
||||
func ValidateUrl(c *cli.Context, allowURLFromArgs bool) (*url.URL, error) {
|
||||
var url = c.String("url")
|
||||
if allowURLFromArgs && c.NArg() > 0 {
|
||||
if c.IsSet("url") {
|
||||
return nil, errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
|
||||
}
|
||||
url = c.Args().Get(0)
|
||||
}
|
||||
validUrl, err := validation.ValidateUrl(url)
|
||||
return validUrl, err
|
||||
}
|
||||
|
||||
type UnvalidatedIngressRule struct {
|
||||
Hostname string
|
||||
Path string
|
||||
Service string
|
||||
OriginRequest OriginRequestConfig `yaml:"originRequest"`
|
||||
}
|
||||
|
||||
// OriginRequestConfig is a set of optional fields that users may set to
|
||||
// customize how cloudflared sends requests to origin services. It is used to set
|
||||
// up general config that apply to all rules, and also, specific per-rule
|
||||
// config.
|
||||
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
|
||||
type OriginRequestConfig struct {
|
||||
// HTTP proxy timeout for establishing a new connection
|
||||
ConnectTimeout *time.Duration `yaml:"connectTimeout"`
|
||||
// HTTP proxy timeout for completing a TLS handshake
|
||||
TLSTimeout *time.Duration `yaml:"tlsTimeout"`
|
||||
// HTTP proxy TCP keepalive duration
|
||||
TCPKeepAlive *time.Duration `yaml:"tcpKeepAlive"`
|
||||
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
|
||||
NoHappyEyeballs *bool `yaml:"noHappyEyeballs"`
|
||||
// HTTP proxy maximum keepalive connection pool size
|
||||
KeepAliveConnections *int `yaml:"keepAliveConnections"`
|
||||
// HTTP proxy timeout for closing an idle connection
|
||||
KeepAliveTimeout *time.Duration `yaml:"keepAliveTimeout"`
|
||||
// Sets the HTTP Host header for the local webserver.
|
||||
HTTPHostHeader *string `yaml:"httpHostHeader"`
|
||||
// Hostname on the origin server certificate.
|
||||
OriginServerName *string `yaml:"originServerName"`
|
||||
// Path to the CA for the certificate of your origin.
|
||||
// This option should be used only if your certificate is not signed by Cloudflare.
|
||||
CAPool *string `yaml:"caPool"`
|
||||
// Disables TLS verification of the certificate presented by your origin.
|
||||
// Will allow any certificate from the origin to be accepted.
|
||||
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
|
||||
NoTLSVerify *bool `yaml:"noTLSVerify"`
|
||||
// Disables chunked transfer encoding.
|
||||
// Useful if you are running a WSGI server.
|
||||
DisableChunkedEncoding *bool `yaml:"disableChunkedEncoding"`
|
||||
// Runs as jump host
|
||||
BastionMode *bool `yaml:"bastionMode"`
|
||||
// Listen address for the proxy.
|
||||
ProxyAddress *string `yaml:"proxyAddress"`
|
||||
// Listen port for the proxy.
|
||||
ProxyPort *uint `yaml:"proxyPort"`
|
||||
// Valid options are 'socks' or empty.
|
||||
ProxyType *string `yaml:"proxyType"`
|
||||
}
|
||||
|
||||
type Configuration struct {
|
||||
TunnelID string `yaml:"tunnel"`
|
||||
Ingress []UnvalidatedIngressRule
|
||||
WarpRouting WarpRoutingConfig `yaml:"warp-routing"`
|
||||
OriginRequest OriginRequestConfig `yaml:"originRequest"`
|
||||
sourceFile string
|
||||
}
|
||||
|
||||
type WarpRoutingConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
type configFileSettings struct {
|
||||
Configuration `yaml:",inline"`
|
||||
// older settings will be aggregated into the generic map, should be read via cli.Context
|
||||
Settings map[string]interface{} `yaml:",inline"`
|
||||
}
|
||||
|
||||
func (c *Configuration) Source() string {
|
||||
return c.sourceFile
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Int(name string) (int, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(int); ok {
|
||||
return v, nil
|
||||
}
|
||||
return 0, fmt.Errorf("expected int found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Duration(name string) (time.Duration, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
switch v := raw.(type) {
|
||||
case time.Duration:
|
||||
return v, nil
|
||||
case string:
|
||||
return time.ParseDuration(v)
|
||||
}
|
||||
return 0, fmt.Errorf("expected duration found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Float64(name string) (float64, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(float64); ok {
|
||||
return v, nil
|
||||
}
|
||||
return 0, fmt.Errorf("expected float found %T for %s", raw, name)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) String(name string) (string, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(string); ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", fmt.Errorf("expected string found %T for %s", raw, name)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) StringSlice(name string) ([]string, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if slice, ok := raw.([]interface{}); ok {
|
||||
strSlice := make([]string, len(slice))
|
||||
for i, v := range slice {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected string, found %T for %v", i, v)
|
||||
}
|
||||
strSlice[i] = str
|
||||
}
|
||||
return strSlice, nil
|
||||
}
|
||||
return nil, fmt.Errorf("expected string slice found %T for %s", raw, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) IntSlice(name string) ([]int, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if slice, ok := raw.([]interface{}); ok {
|
||||
intSlice := make([]int, len(slice))
|
||||
for i, v := range slice {
|
||||
str, ok := v.(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected int, found %T for %v ", v, v)
|
||||
}
|
||||
intSlice[i] = str
|
||||
}
|
||||
return intSlice, nil
|
||||
}
|
||||
if v, ok := raw.([]int); ok {
|
||||
return v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("expected int slice found %T for %s", raw, name)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Generic(name string) (cli.Generic, error) {
|
||||
return nil, errors.New("option type Generic not supported")
|
||||
}
|
||||
|
||||
func (c *configFileSettings) Bool(name string) (bool, error) {
|
||||
if raw, ok := c.Settings[name]; ok {
|
||||
if v, ok := raw.(bool); ok {
|
||||
return v, nil
|
||||
}
|
||||
return false, fmt.Errorf("expected boolean found %T for %s", raw, name)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
var configuration configFileSettings
|
||||
|
||||
func GetConfiguration() *Configuration {
|
||||
return &configuration.Configuration
|
||||
}
|
||||
|
||||
// ReadConfigFile returns InputSourceContext initialized from the configuration file.
|
||||
// On repeat calls returns with the same file, returns without reading the file again; however,
|
||||
// if value of "config" flag changes, will read the new config file
|
||||
func ReadConfigFile(c *cli.Context, log *zerolog.Logger) (*configFileSettings, error) {
|
||||
configFile := c.String("config")
|
||||
if configuration.Source() == configFile || configFile == "" {
|
||||
if configuration.Source() == "" {
|
||||
return nil, ErrNoConfigFile
|
||||
}
|
||||
return &configuration, nil
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Loading configuration from %s", configFile)
|
||||
file, err := os.Open(configFile)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
err = ErrNoConfigFile
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
if err := yaml.NewDecoder(file).Decode(&configuration); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Error().Msgf("Configuration file %s was empty", configFile)
|
||||
return &configuration, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "error parsing YAML in config file at "+configFile)
|
||||
}
|
||||
configuration.sourceFile = configFile
|
||||
return &configuration, nil
|
||||
}
|
@@ -1,83 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func TestConfigFileSettings(t *testing.T) {
|
||||
var (
|
||||
firstIngress = UnvalidatedIngressRule{
|
||||
Hostname: "tunnel1.example.com",
|
||||
Path: "/id",
|
||||
Service: "https://localhost:8000",
|
||||
}
|
||||
secondIngress = UnvalidatedIngressRule{
|
||||
Hostname: "*",
|
||||
Path: "",
|
||||
Service: "https://localhost:8001",
|
||||
}
|
||||
warpRouting = WarpRoutingConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
)
|
||||
rawYAML := `
|
||||
tunnel: config-file-test
|
||||
ingress:
|
||||
- hostname: tunnel1.example.com
|
||||
path: /id
|
||||
service: https://localhost:8000
|
||||
- hostname: "*"
|
||||
service: https://localhost:8001
|
||||
warp-routing:
|
||||
enabled: true
|
||||
retries: 5
|
||||
grace-period: 30s
|
||||
percentage: 3.14
|
||||
hostname: example.com
|
||||
tag:
|
||||
- test
|
||||
- central-1
|
||||
counters:
|
||||
- 123
|
||||
- 456
|
||||
`
|
||||
var config configFileSettings
|
||||
err := yaml.Unmarshal([]byte(rawYAML), &config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "config-file-test", config.TunnelID)
|
||||
assert.Equal(t, firstIngress, config.Ingress[0])
|
||||
assert.Equal(t, secondIngress, config.Ingress[1])
|
||||
assert.Equal(t, warpRouting, config.WarpRouting)
|
||||
|
||||
retries, err := config.Int("retries")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 5, retries)
|
||||
|
||||
gracePeriod, err := config.Duration("grace-period")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, time.Second*30, gracePeriod)
|
||||
|
||||
percentage, err := config.Float64("percentage")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3.14, percentage)
|
||||
|
||||
hostname, err := config.String("hostname")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "example.com", hostname)
|
||||
|
||||
tags, err := config.StringSlice("tag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test", tags[0])
|
||||
assert.Equal(t, "central-1", tags[1])
|
||||
|
||||
counters, err := config.IntSlice("counters")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 123, counters[0])
|
||||
assert.Equal(t, 456, counters[1])
|
||||
|
||||
}
|
@@ -1,112 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Notifier sends out config updates
|
||||
type Notifier interface {
|
||||
ConfigDidUpdate(Root)
|
||||
}
|
||||
|
||||
// Manager is the base functions of the config manager
|
||||
type Manager interface {
|
||||
Start(Notifier) error
|
||||
Shutdown()
|
||||
}
|
||||
|
||||
// FileManager watches the yaml config for changes
|
||||
// sends updates to the service to reconfigure to match the updated config
|
||||
type FileManager struct {
|
||||
watcher watcher.Notifier
|
||||
notifier Notifier
|
||||
configPath string
|
||||
log *zerolog.Logger
|
||||
ReadConfig func(string, *zerolog.Logger) (Root, error)
|
||||
}
|
||||
|
||||
// NewFileManager creates a config manager
|
||||
func NewFileManager(watcher watcher.Notifier, configPath string, log *zerolog.Logger) (*FileManager, error) {
|
||||
m := &FileManager{
|
||||
watcher: watcher,
|
||||
configPath: configPath,
|
||||
log: log,
|
||||
ReadConfig: readConfigFromPath,
|
||||
}
|
||||
err := watcher.Add(configPath)
|
||||
return m, err
|
||||
}
|
||||
|
||||
// Start starts the runloop to watch for config changes
|
||||
func (m *FileManager) Start(notifier Notifier) error {
|
||||
m.notifier = notifier
|
||||
|
||||
// update the notifier with a fresh config on start
|
||||
config, err := m.GetConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
notifier.ConfigDidUpdate(config)
|
||||
|
||||
m.watcher.Start(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig reads the yaml file from the disk
|
||||
func (m *FileManager) GetConfig() (Root, error) {
|
||||
return m.ReadConfig(m.configPath, m.log)
|
||||
}
|
||||
|
||||
// Shutdown stops the watcher
|
||||
func (m *FileManager) Shutdown() {
|
||||
m.watcher.Shutdown()
|
||||
}
|
||||
|
||||
func readConfigFromPath(configPath string, log *zerolog.Logger) (Root, error) {
|
||||
if configPath == "" {
|
||||
return Root{}, errors.New("unable to find config file")
|
||||
}
|
||||
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
return Root{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var config Root
|
||||
if err := yaml.NewDecoder(file).Decode(&config); err != nil {
|
||||
if err == io.EOF {
|
||||
log.Error().Msgf("Configuration file %s was empty", configPath)
|
||||
return Root{}, nil
|
||||
}
|
||||
return Root{}, errors.Wrap(err, "error parsing YAML in config file at "+configPath)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// File change notifications from the watcher
|
||||
|
||||
// WatcherItemDidChange triggers when the yaml config is updated
|
||||
// sends the updated config to the service to reload its state
|
||||
func (m *FileManager) WatcherItemDidChange(filepath string) {
|
||||
config, err := m.GetConfig()
|
||||
if err != nil {
|
||||
m.log.Err(err).Msg("Failed to read new config")
|
||||
return
|
||||
}
|
||||
m.log.Info().Msg("Config file has been updated")
|
||||
m.notifier.ConfigDidUpdate(config)
|
||||
}
|
||||
|
||||
// WatcherDidError notifies of errors with the file watcher
|
||||
func (m *FileManager) WatcherDidError(err error) {
|
||||
m.log.Err(err).Msg("Config watcher encountered an error")
|
||||
}
|
@@ -1,88 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockNotifier struct {
|
||||
configs []Root
|
||||
}
|
||||
|
||||
func (n *mockNotifier) ConfigDidUpdate(c Root) {
|
||||
n.configs = append(n.configs, c)
|
||||
}
|
||||
|
||||
type mockFileWatcher struct {
|
||||
path string
|
||||
notifier watcher.Notification
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Start(n watcher.Notification) {
|
||||
w.notifier = n
|
||||
w.ready <- struct{}{}
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Add(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) Shutdown() {
|
||||
|
||||
}
|
||||
|
||||
func (w *mockFileWatcher) TriggerChange() {
|
||||
w.notifier.WatcherItemDidChange(w.path)
|
||||
}
|
||||
|
||||
func TestConfigChanged(t *testing.T) {
|
||||
filePath := "config.yaml"
|
||||
f, err := os.Create(filePath)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
_ = os.Remove(filePath)
|
||||
}()
|
||||
c := &Root{
|
||||
Forwarders: []Forwarder{
|
||||
{
|
||||
URL: "test.daltoniam.com",
|
||||
Listener: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
}
|
||||
configRead := func(configPath string, log *zerolog.Logger) (Root, error) {
|
||||
return *c, nil
|
||||
}
|
||||
wait := make(chan struct{})
|
||||
w := &mockFileWatcher{path: filePath, ready: wait}
|
||||
|
||||
log := zerolog.Nop()
|
||||
|
||||
service, err := NewFileManager(w, filePath, &log)
|
||||
service.ReadConfig = configRead
|
||||
assert.NoError(t, err)
|
||||
|
||||
n := &mockNotifier{}
|
||||
go service.Start(n)
|
||||
|
||||
<-wait
|
||||
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
|
||||
w.TriggerChange()
|
||||
|
||||
service.Shutdown()
|
||||
|
||||
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")
|
||||
assert.Len(t, n.configs[0].Forwarders, 1, "not the amount of forwarders expected")
|
||||
assert.Len(t, n.configs[1].Forwarders, 2, "not the amount of forwarders expected")
|
||||
|
||||
assert.Equal(t, n.configs[0].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
|
||||
assert.Equal(t, n.configs[1].Forwarders[0].Hash(), c.Forwarders[0].Hash(), "forwarder hashes don't match")
|
||||
assert.Equal(t, n.configs[1].Forwarders[1].Hash(), c.Forwarders[1].Hash(), "forwarder hashes don't match")
|
||||
}
|
@@ -1,113 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
)
|
||||
|
||||
// Forwarder represents a client side listener to forward traffic to the edge
|
||||
type Forwarder struct {
|
||||
URL string `json:"url"`
|
||||
Listener string `json:"listener"`
|
||||
TokenClientID string `json:"service_token_id" yaml:"serviceTokenID"`
|
||||
TokenSecret string `json:"secret_token_id" yaml:"serviceTokenSecret"`
|
||||
Destination string `json:"destination"`
|
||||
}
|
||||
|
||||
// Tunnel represents a tunnel that should be started
|
||||
type Tunnel struct {
|
||||
URL string `json:"url"`
|
||||
Origin string `json:"origin"`
|
||||
ProtocolType string `json:"type"`
|
||||
}
|
||||
|
||||
// DNSResolver represents a client side DNS resolver
|
||||
type DNSResolver struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Address string `json:"address,omitempty"`
|
||||
Port uint16 `json:"port,omitempty"`
|
||||
Upstreams []string `json:"upstreams,omitempty"`
|
||||
Bootstraps []string `json:"bootstraps,omitempty"`
|
||||
MaxUpstreamConnections int `json:"max_upstream_connections,omitempty"`
|
||||
}
|
||||
|
||||
// Root is the base options to configure the service
|
||||
type Root struct {
|
||||
LogDirectory string `json:"log_directory" yaml:"logDirectory,omitempty"`
|
||||
LogLevel string `json:"log_level" yaml:"logLevel,omitempty"`
|
||||
Forwarders []Forwarder `json:"forwarders,omitempty" yaml:"forwarders,omitempty"`
|
||||
Tunnels []Tunnel `json:"tunnels,omitempty" yaml:"tunnels,omitempty"`
|
||||
Resolver DNSResolver `json:"resolver,omitempty" yaml:"resolver,omitempty"`
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
func (f *Forwarder) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, f.URL)
|
||||
io.WriteString(h, f.Listener)
|
||||
io.WriteString(h, f.TokenClientID)
|
||||
io.WriteString(h, f.TokenSecret)
|
||||
io.WriteString(h, f.Destination)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// Hash returns the computed values to see if the forwarder values change
|
||||
func (r *DNSResolver) Hash() string {
|
||||
h := md5.New()
|
||||
io.WriteString(h, r.Address)
|
||||
io.WriteString(h, strings.Join(r.Bootstraps, ","))
|
||||
io.WriteString(h, strings.Join(r.Upstreams, ","))
|
||||
io.WriteString(h, fmt.Sprintf("%d", r.Port))
|
||||
io.WriteString(h, fmt.Sprintf("%d", r.MaxUpstreamConnections))
|
||||
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// EnabledOrDefault returns the enabled property
|
||||
func (r *DNSResolver) EnabledOrDefault() bool {
|
||||
return r.Enabled
|
||||
}
|
||||
|
||||
// AddressOrDefault returns the address or returns the default if empty
|
||||
func (r *DNSResolver) AddressOrDefault() string {
|
||||
if r.Address != "" {
|
||||
return r.Address
|
||||
}
|
||||
return "localhost"
|
||||
}
|
||||
|
||||
// PortOrDefault return the port or returns the default if 0
|
||||
func (r *DNSResolver) PortOrDefault() uint16 {
|
||||
if r.Port > 0 {
|
||||
return r.Port
|
||||
}
|
||||
return 53
|
||||
}
|
||||
|
||||
// UpstreamsOrDefault returns the upstreams or returns the default if empty
|
||||
func (r *DNSResolver) UpstreamsOrDefault() []string {
|
||||
if len(r.Upstreams) > 0 {
|
||||
return r.Upstreams
|
||||
}
|
||||
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
|
||||
}
|
||||
|
||||
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
|
||||
func (r *DNSResolver) BootstrapsOrDefault() []string {
|
||||
if len(r.Bootstraps) > 0 {
|
||||
return r.Bootstraps
|
||||
}
|
||||
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
|
||||
}
|
||||
|
||||
// MaxUpstreamConnectionsOrDefault return the max upstream connections or returns the default if negative
|
||||
func (r *DNSResolver) MaxUpstreamConnectionsOrDefault() int {
|
||||
if r.MaxUpstreamConnections >= 0 {
|
||||
return r.MaxUpstreamConnections
|
||||
}
|
||||
return tunneldns.MaxUpstreamConnsDefault
|
||||
}
|
@@ -1,176 +0,0 @@
|
||||
// Package encrypter is suitable for encrypting messages you would like to securely share between two points.
|
||||
// Useful for providing end to end encryption (E2EE). It uses Box (NaCl) for encrypting the messages.
|
||||
// tldr is it uses Elliptic Curves (Curve25519) for the keys, XSalsa20 and Poly1305 for encryption.
|
||||
// You can read more here https://godoc.org/golang.org/x/crypto/nacl/box.
|
||||
//
|
||||
// msg := []byte("super safe message.")
|
||||
// alice, err := New("alice_priv_key.pem", "alice_pub_key.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// bob, err := New("bob_priv_key.pem", "bob_pub_key.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// encrypted, err := alice.Encrypt(msg, bob.PublicKey())
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// data, err := bob.Decrypt(encrypted, alice.PublicKey())
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// fmt.Println(string(data))
|
||||
package encrypter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"golang.org/x/crypto/nacl/box"
|
||||
)
|
||||
|
||||
// Encrypter represents a keypair value with auxiliary functions to make
|
||||
// doing encryption and decryption easier
|
||||
type Encrypter struct {
|
||||
privateKey *[32]byte
|
||||
publicKey *[32]byte
|
||||
}
|
||||
|
||||
// New returns a new encrypter with initialized keypair
|
||||
func New(privateKey, publicKey string) (*Encrypter, error) {
|
||||
e := &Encrypter{}
|
||||
pubKey, key, err := e.fetchOrGenerateKeys(privateKey, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.privateKey, e.publicKey = key, pubKey
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// PublicKey returns a base64 encoded public key. Useful for transport (like in HTTP requests)
|
||||
func (e *Encrypter) PublicKey() string {
|
||||
return base64.URLEncoding.EncodeToString(e.publicKey[:])
|
||||
}
|
||||
|
||||
// Decrypt data that was encrypted using our publicKey. It will use our privateKey and the sender's publicKey to decrypt
|
||||
// data is an encrypted buffer of data, mostly like from the Encrypt function. Messages contain the nonce data on the front
|
||||
// of the message.
|
||||
// senderPublicKey is a base64 encoded version of the sender's public key (most likely from the PublicKey function).
|
||||
// The return value is the decrypted buffer or an error.
|
||||
func (e *Encrypter) Decrypt(data []byte, senderPublicKey string) ([]byte, error) {
|
||||
var decryptNonce [24]byte
|
||||
copy(decryptNonce[:], data[:24]) // we pull the nonce from the front of the actual message.
|
||||
pubKey, err := e.decodePublicKey(senderPublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypted, ok := box.Open(nil, data[24:], &decryptNonce, pubKey, e.privateKey)
|
||||
if !ok {
|
||||
return nil, errors.New("failed to decrypt message")
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
// Encrypt data using our privateKey and the recipient publicKey
|
||||
// data is a buffer of data that we would like to encrypt. Messages will have the nonce added to front
|
||||
// as they have to unique for each message shared.
|
||||
// recipientPublicKey is a base64 encoded version of the sender's public key (most likely from the PublicKey function).
|
||||
// The return value is the encrypted buffer or an error.
|
||||
func (e *Encrypter) Encrypt(data []byte, recipientPublicKey string) ([]byte, error) {
|
||||
var nonce [24]byte
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKey, err := e.decodePublicKey(recipientPublicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// This encrypts msg and adds the nonce to the front of the message, since the nonce has to be
|
||||
// the same for encrypting and decrypting
|
||||
return box.Seal(nonce[:], data, &nonce, pubKey, e.privateKey), nil
|
||||
}
|
||||
|
||||
// WriteKeys keys will take the currently initialized keypair and write them to provided filenames
|
||||
func (e *Encrypter) WriteKeys(privateKey, publicKey string) error {
|
||||
if err := e.writeKey(e.privateKey[:], "BOX PRIVATE KEY", privateKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return e.writeKey(e.publicKey[:], "PUBLIC KEY", publicKey)
|
||||
}
|
||||
|
||||
// fetchOrGenerateKeys will either load or create a keypair if it doesn't exist
|
||||
func (e *Encrypter) fetchOrGenerateKeys(privateKey, publicKey string) (*[32]byte, *[32]byte, error) {
|
||||
key, err := e.fetchKey(privateKey)
|
||||
if os.IsNotExist(err) {
|
||||
return box.GenerateKey(rand.Reader)
|
||||
} else if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
pub, err := e.fetchKey(publicKey)
|
||||
if os.IsNotExist(err) {
|
||||
return box.GenerateKey(rand.Reader)
|
||||
} else if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return pub, key, nil
|
||||
}
|
||||
|
||||
// writeKey will write a key to disk in DER format (it's a standard pem key)
|
||||
func (e *Encrypter) writeKey(key []byte, pemType, filename string) error {
|
||||
data := pem.EncodeToMemory(&pem.Block{
|
||||
Type: pemType,
|
||||
Bytes: key,
|
||||
})
|
||||
|
||||
f, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = f.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchKey will load a a DER formatted key from disk
|
||||
func (e *Encrypter) fetchKey(filename string) (*[32]byte, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
io.Copy(buf, f)
|
||||
|
||||
p, _ := pem.Decode(buf.Bytes())
|
||||
if p == nil {
|
||||
return nil, errors.New("Failed to decode key")
|
||||
}
|
||||
var newKey [32]byte
|
||||
copy(newKey[:], p.Bytes)
|
||||
|
||||
return &newKey, nil
|
||||
}
|
||||
|
||||
// decodePublicKey will base64 decode the provided key to the box representation
|
||||
func (e *Encrypter) decodePublicKey(key string) (*[32]byte, error) {
|
||||
pub, err := base64.URLEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var newKey [32]byte
|
||||
copy(newKey[:], pub)
|
||||
return &newKey, nil
|
||||
}
|
@@ -8,8 +8,8 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
@@ -8,13 +8,13 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/overwatch"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
"github.com/cloudflare/cloudflared/watcher"
|
||||
|
||||
"github.com/getsentry/raven-go"
|
||||
@@ -130,7 +130,7 @@ To determine if an update happened in a script, check for error code 11.`,
|
||||
},
|
||||
}
|
||||
cmds = append(cmds, tunnel.Commands()...)
|
||||
cmds = append(cmds, tunneldns.Command(false))
|
||||
cmds = append(cmds, proxydns.Command(false))
|
||||
cmds = append(cmds, access.Commands()...)
|
||||
return cmds
|
||||
}
|
||||
|
@@ -1,45 +0,0 @@
|
||||
package path
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
)
|
||||
|
||||
// GenerateAppTokenFilePathFromURL will return a filepath for given Access org token
|
||||
func GenerateAppTokenFilePathFromURL(url *url.URL, suffix string) (string, error) {
|
||||
configPath, err := getConfigPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := strings.Replace(fmt.Sprintf("%s%s-%s", url.Hostname(), url.EscapedPath(), suffix), "/", "-", -1)
|
||||
return filepath.Join(configPath, name), nil
|
||||
}
|
||||
|
||||
// GenerateOrgTokenFilePathFromURL will return a filepath for given Access application token
|
||||
func GenerateOrgTokenFilePathFromURL(authDomain string) (string, error) {
|
||||
configPath, err := getConfigPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
name := strings.Replace(fmt.Sprintf("%s-org-token", authDomain), "/", "-", -1)
|
||||
return filepath.Join(configPath, name), nil
|
||||
}
|
||||
|
||||
func getConfigPath() (string, error) {
|
||||
configPath, err := homedir.Expand(config.DefaultConfigSearchDirectories()[0])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ok, err := config.FileExists(configPath)
|
||||
if !ok && err == nil {
|
||||
// create config directory if doesn't already exist
|
||||
err = os.Mkdir(configPath, 0700)
|
||||
}
|
||||
return configPath, err
|
||||
}
|
115
cmd/cloudflared/proxydns/cmd.go
Normal file
115
cmd/cloudflared/proxydns/cmd.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package proxydns
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/metrics"
|
||||
"github.com/cloudflare/cloudflared/tunneldns"
|
||||
)
|
||||
|
||||
func Command(hidden bool) *cli.Command {
|
||||
return &cli.Command{
|
||||
Name: "proxy-dns",
|
||||
Action: cliutil.ErrorHandler(Run),
|
||||
Usage: "Run a DNS over HTTPS proxy server.",
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "metrics",
|
||||
Value: "localhost:",
|
||||
Usage: "Listen address for metrics reporting.",
|
||||
EnvVars: []string{"TUNNEL_METRICS"},
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "address",
|
||||
Usage: "Listen address for the DNS over HTTPS proxy server.",
|
||||
Value: "localhost",
|
||||
EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
|
||||
},
|
||||
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
|
||||
&cli.IntFlag{
|
||||
Name: "port",
|
||||
Usage: "Listen on given port for the DNS over HTTPS proxy server.",
|
||||
Value: 53,
|
||||
EnvVars: []string{"TUNNEL_DNS_PORT"},
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "upstream",
|
||||
Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
|
||||
Value: cli.NewStringSlice("https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"),
|
||||
EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
|
||||
},
|
||||
&cli.StringSliceFlag{
|
||||
Name: "bootstrap",
|
||||
Usage: "bootstrap endpoint URL, you can specify multiple endpoints for redundancy.",
|
||||
Value: cli.NewStringSlice("https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"),
|
||||
EnvVars: []string{"TUNNEL_DNS_BOOTSTRAP"},
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "max-upstream-conns",
|
||||
Usage: "Maximum concurrent connections to upstream. Setting to 0 means unlimited.",
|
||||
Value: tunneldns.MaxUpstreamConnsDefault,
|
||||
EnvVars: []string{"TUNNEL_DNS_MAX_UPSTREAM_CONNS"},
|
||||
},
|
||||
},
|
||||
ArgsUsage: " ", // can't be the empty string or we get the default output
|
||||
Hidden: hidden,
|
||||
}
|
||||
}
|
||||
|
||||
// Run implements a foreground runner
|
||||
func Run(c *cli.Context) error {
|
||||
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
|
||||
|
||||
metricsListener, err := net.Listen("tcp", c.String("metrics"))
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to open the metrics listener")
|
||||
}
|
||||
|
||||
go metrics.ServeMetrics(metricsListener, nil, nil, log)
|
||||
|
||||
listener, err := tunneldns.CreateListener(
|
||||
c.String("address"),
|
||||
// Note TUN-3758 , we use Int because UInt is not supported with altsrc
|
||||
uint16(c.Int("port")),
|
||||
c.StringSlice("upstream"),
|
||||
c.StringSlice("bootstrap"),
|
||||
c.Int("max-upstream-conns"),
|
||||
log,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to create the listeners")
|
||||
return err
|
||||
}
|
||||
|
||||
// Try to start the server
|
||||
readySignal := make(chan struct{})
|
||||
err = listener.Start(readySignal)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to start the listeners")
|
||||
return listener.Stop()
|
||||
}
|
||||
<-readySignal
|
||||
|
||||
// Wait for signal
|
||||
signals := make(chan os.Signal, 10)
|
||||
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
|
||||
defer signal.Stop(signals)
|
||||
<-signals
|
||||
|
||||
// Shut down server
|
||||
err = listener.Stop()
|
||||
if err != nil {
|
||||
log.Err(err).Msg("failed to stop")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
"github.com/mitchellh/go-homedir"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
)
|
||||
|
||||
type ServiceTemplate struct {
|
||||
|
@@ -1,11 +0,0 @@
|
||||
//+build darwin
|
||||
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func getBrowserCmd(url string) *exec.Cmd {
|
||||
return exec.Command("open", url)
|
||||
}
|
@@ -1,11 +0,0 @@
|
||||
//+build !windows,!darwin,!linux,!netbsd,!freebsd,!openbsd
|
||||
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func getBrowserCmd(url string) *exec.Cmd {
|
||||
return nil
|
||||
}
|
@@ -1,11 +0,0 @@
|
||||
//+build linux freebsd openbsd netbsd
|
||||
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func getBrowserCmd(url string) *exec.Cmd {
|
||||
return exec.Command("xdg-open", url)
|
||||
}
|
@@ -1,18 +0,0 @@
|
||||
//+build windows
|
||||
|
||||
package shell
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func getBrowserCmd(url string) *exec.Cmd {
|
||||
cmd := exec.Command("cmd")
|
||||
// CmdLine is only defined when compiling for windows.
|
||||
// Empty string is the cmd proc "Title". Needs to be included because the start command will interpret the first
|
||||
// quoted string as that field and we want to quote the URL.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{CmdLine: fmt.Sprintf(`/c start "" "%s"`, url)}
|
||||
return cmd
|
||||
}
|
@@ -1,33 +0,0 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
// OpenBrowser opens the specified URL in the default browser of the user
|
||||
func OpenBrowser(url string) error {
|
||||
return getBrowserCmd(url).Start()
|
||||
}
|
||||
|
||||
// Run will kick off a shell task and pipe the results to the respective std pipes
|
||||
func Run(cmd string, args ...string) error {
|
||||
c := exec.Command(cmd, args...)
|
||||
stderr, err := c.StderrPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stderr, stderr)
|
||||
}()
|
||||
|
||||
stdout, err := c.StdoutPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
io.Copy(os.Stdout, stdout)
|
||||
}()
|
||||
return c.Run()
|
||||
}
|
@@ -1,388 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/path"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/transfer"
|
||||
"github.com/cloudflare/cloudflared/origin"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
keyName = "token"
|
||||
tokenHeader = "CF_Authorization"
|
||||
)
|
||||
|
||||
type lock struct {
|
||||
lockFilePath string
|
||||
backoff *origin.BackoffHandler
|
||||
sigHandler *signalHandler
|
||||
}
|
||||
|
||||
type signalHandler struct {
|
||||
sigChannel chan os.Signal
|
||||
signals []os.Signal
|
||||
}
|
||||
|
||||
type appJWTPayload struct {
|
||||
Aud []string `json:"aud"`
|
||||
Email string `json:"email"`
|
||||
Exp int `json:"exp"`
|
||||
Iat int `json:"iat"`
|
||||
Nbf int `json:"nbf"`
|
||||
Iss string `json:"iss"`
|
||||
Type string `json:"type"`
|
||||
Subt string `json:"sub"`
|
||||
}
|
||||
|
||||
type orgJWTPayload struct {
|
||||
appJWTPayload
|
||||
Aud string `json:"aud"`
|
||||
}
|
||||
|
||||
type transferServiceResponse struct {
|
||||
AppToken string `json:"app_token"`
|
||||
OrgToken string `json:"org_token"`
|
||||
}
|
||||
|
||||
func (p appJWTPayload) isExpired() bool {
|
||||
return int(time.Now().Unix()) > p.Exp
|
||||
}
|
||||
|
||||
func (s *signalHandler) register(handler func()) {
|
||||
s.sigChannel = make(chan os.Signal, 1)
|
||||
signal.Notify(s.sigChannel, s.signals...)
|
||||
go func(s *signalHandler) {
|
||||
for range s.sigChannel {
|
||||
handler()
|
||||
}
|
||||
}(s)
|
||||
}
|
||||
|
||||
func (s *signalHandler) deregister() {
|
||||
signal.Stop(s.sigChannel)
|
||||
close(s.sigChannel)
|
||||
}
|
||||
|
||||
func errDeleteTokenFailed(lockFilePath string) error {
|
||||
return fmt.Errorf("failed to acquire a new Access token. Please try to delete %s", lockFilePath)
|
||||
}
|
||||
|
||||
// newLock will get a new file lock
|
||||
func newLock(path string) *lock {
|
||||
lockPath := path + ".lock"
|
||||
return &lock{
|
||||
lockFilePath: lockPath,
|
||||
backoff: &origin.BackoffHandler{MaxRetries: 7},
|
||||
sigHandler: &signalHandler{
|
||||
signals: []os.Signal{syscall.SIGINT, syscall.SIGTERM},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lock) Acquire() error {
|
||||
// Intercept SIGINT and SIGTERM to release lock before exiting
|
||||
l.sigHandler.register(func() {
|
||||
_ = l.deleteLockFile()
|
||||
os.Exit(0)
|
||||
})
|
||||
|
||||
// Check for a path.lock file
|
||||
// if the lock file exists; start polling
|
||||
// if not, create the lock file and go through the normal flow.
|
||||
// See AUTH-1736 for the reason why we do all this
|
||||
for isTokenLocked(l.lockFilePath) {
|
||||
if l.backoff.Backoff(context.Background()) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := l.deleteLockFile(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Create a lock file so other processes won't also try to get the token at
|
||||
// the same time
|
||||
if err := ioutil.WriteFile(l.lockFilePath, []byte{}, 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *lock) deleteLockFile() error {
|
||||
if err := os.Remove(l.lockFilePath); err != nil && !os.IsNotExist(err) {
|
||||
return errDeleteTokenFailed(l.lockFilePath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *lock) Release() error {
|
||||
defer l.sigHandler.deregister()
|
||||
return l.deleteLockFile()
|
||||
}
|
||||
|
||||
// isTokenLocked checks to see if there is another process attempting to get the token already
|
||||
func isTokenLocked(lockFilePath string) bool {
|
||||
exists, err := config.FileExists(lockFilePath)
|
||||
return exists && err == nil
|
||||
}
|
||||
|
||||
// FetchTokenWithRedirect will either load a stored token or generate a new one
|
||||
// it appends the full url as the redirect URL to the access cli request if opening the browser
|
||||
func FetchTokenWithRedirect(appURL *url.URL, log *zerolog.Logger) (string, error) {
|
||||
return getToken(appURL, false, log)
|
||||
}
|
||||
|
||||
// FetchToken will either load a stored token or generate a new one
|
||||
// it appends the host of the appURL as the redirect URL to the access cli request if opening the browser
|
||||
func FetchToken(appURL *url.URL, log *zerolog.Logger) (string, error) {
|
||||
return getToken(appURL, true, log)
|
||||
}
|
||||
|
||||
// getToken will either load a stored token or generate a new one
|
||||
func getToken(appURL *url.URL, useHostOnly bool, log *zerolog.Logger) (string, error) {
|
||||
if token, err := GetAppTokenIfExists(appURL); token != "" && err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
appTokenPath, err := path.GenerateAppTokenFilePathFromURL(appURL, keyName)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to generate app token file path")
|
||||
}
|
||||
|
||||
fileLockAppToken := newLock(appTokenPath)
|
||||
if err = fileLockAppToken.Acquire(); err != nil {
|
||||
return "", errors.Wrap(err, "failed to acquire app token lock")
|
||||
}
|
||||
defer fileLockAppToken.Release()
|
||||
|
||||
// check to see if another process has gotten a token while we waited for the lock
|
||||
if token, err := GetAppTokenIfExists(appURL); token != "" && err == nil {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// If an app token couldnt be found on disk, check for an org token and attempt to exchange it for an app token.
|
||||
var orgTokenPath string
|
||||
// Get auth domain to format into org token file path
|
||||
if authDomain, err := getAuthDomain(appURL); err != nil {
|
||||
log.Error().Msgf("failed to get auth domain: %s", err)
|
||||
} else {
|
||||
orgToken, err := GetOrgTokenIfExists(authDomain)
|
||||
if err != nil {
|
||||
orgTokenPath, err = path.GenerateOrgTokenFilePathFromURL(authDomain)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to generate org token file path")
|
||||
}
|
||||
|
||||
fileLockOrgToken := newLock(orgTokenPath)
|
||||
if err = fileLockOrgToken.Acquire(); err != nil {
|
||||
return "", errors.Wrap(err, "failed to acquire org token lock")
|
||||
}
|
||||
defer fileLockOrgToken.Release()
|
||||
// check if an org token has been created since the lock was acquired
|
||||
orgToken, err = GetOrgTokenIfExists(authDomain)
|
||||
}
|
||||
if err == nil {
|
||||
if appToken, err := exchangeOrgToken(appURL, orgToken); err != nil {
|
||||
log.Debug().Msgf("failed to exchange org token for app token: %s", err)
|
||||
} else {
|
||||
if err := ioutil.WriteFile(appTokenPath, []byte(appToken), 0600); err != nil {
|
||||
return "", errors.Wrap(err, "failed to write app token to disk")
|
||||
}
|
||||
return appToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return getTokensFromEdge(appURL, appTokenPath, orgTokenPath, useHostOnly, log)
|
||||
|
||||
}
|
||||
|
||||
// getTokensFromEdge will attempt to use the transfer service to retrieve an app and org token, save them to disk,
|
||||
// and return the app token.
|
||||
func getTokensFromEdge(appURL *url.URL, appTokenPath, orgTokenPath string, useHostOnly bool, log *zerolog.Logger) (string, error) {
|
||||
// If no org token exists or if it couldnt be exchanged for an app token, then run the transfer service flow.
|
||||
|
||||
// this weird parameter is the resource name (token) and the key/value
|
||||
// we want to send to the transfer service. the key is token and the value
|
||||
// is blank (basically just the id generated in the transfer service)
|
||||
resourceData, err := transfer.Run(appURL, keyName, keyName, "", true, useHostOnly, log)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to run transfer service")
|
||||
}
|
||||
var resp transferServiceResponse
|
||||
if err = json.Unmarshal(resourceData, &resp); err != nil {
|
||||
return "", errors.Wrap(err, "failed to marshal transfer service response")
|
||||
}
|
||||
|
||||
// If we were able to get the auth domain and generate an org token path, lets write it to disk.
|
||||
if orgTokenPath != "" {
|
||||
if err := ioutil.WriteFile(orgTokenPath, []byte(resp.OrgToken), 0600); err != nil {
|
||||
return "", errors.Wrap(err, "failed to write org token to disk")
|
||||
}
|
||||
}
|
||||
|
||||
if err := ioutil.WriteFile(appTokenPath, []byte(resp.AppToken), 0600); err != nil {
|
||||
return "", errors.Wrap(err, "failed to write app token to disk")
|
||||
}
|
||||
|
||||
return resp.AppToken, nil
|
||||
|
||||
}
|
||||
|
||||
// getAuthDomain makes a request to the appURL and stops at the first redirect. The 302 location header will contain the
|
||||
// auth domain
|
||||
func getAuthDomain(appURL *url.URL) (string, error) {
|
||||
client := &http.Client{
|
||||
// do not follow redirects
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: time.Second * 7,
|
||||
}
|
||||
|
||||
authDomainReq, err := http.NewRequest("HEAD", appURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to create auth domain request")
|
||||
}
|
||||
resp, err := client.Do(authDomainReq)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to get auth domain")
|
||||
}
|
||||
resp.Body.Close()
|
||||
location, err := resp.Location()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get auth domain. Received status code %d from %s", resp.StatusCode, appURL.String())
|
||||
}
|
||||
return location.Hostname(), nil
|
||||
|
||||
}
|
||||
|
||||
// exchangeOrgToken attaches an org token to a request to the appURL and returns an app token. This uses the Access SSO
|
||||
// flow to automatically generate and return an app token without the login page.
|
||||
func exchangeOrgToken(appURL *url.URL, orgToken string) (string, error) {
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
// attach org token to login request
|
||||
if strings.Contains(req.URL.Path, "cdn-cgi/access/login") {
|
||||
req.AddCookie(&http.Cookie{Name: tokenHeader, Value: orgToken})
|
||||
}
|
||||
// stop after hitting authorized endpoint since it will contain the app token
|
||||
if strings.Contains(via[len(via)-1].URL.Path, "cdn-cgi/access/authorized") {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Timeout: time.Second * 7,
|
||||
}
|
||||
|
||||
appTokenRequest, err := http.NewRequest("HEAD", appURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to create app token request")
|
||||
}
|
||||
resp, err := client.Do(appTokenRequest)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(err, "failed to get app token")
|
||||
}
|
||||
resp.Body.Close()
|
||||
var appToken string
|
||||
for _, c := range resp.Cookies() {
|
||||
//if Org token revoked on exchange, getTokensFromEdge instead
|
||||
validAppToken := c.Name == tokenHeader && time.Now().Before(c.Expires)
|
||||
if validAppToken {
|
||||
appToken = c.Value
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(appToken) > 0 {
|
||||
return appToken, nil
|
||||
}
|
||||
return "", fmt.Errorf("response from %s did not contain app token", resp.Request.URL.String())
|
||||
}
|
||||
|
||||
func GetOrgTokenIfExists(authDomain string) (string, error) {
|
||||
path, err := path.GenerateOrgTokenFilePathFromURL(authDomain)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token, err := getTokenIfExists(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var payload orgJWTPayload
|
||||
err = json.Unmarshal(token.Payload, &payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if payload.isExpired() {
|
||||
err := os.Remove(path)
|
||||
return "", err
|
||||
}
|
||||
return token.Encode(), nil
|
||||
}
|
||||
|
||||
func GetAppTokenIfExists(url *url.URL) (string, error) {
|
||||
path, err := path.GenerateAppTokenFilePathFromURL(url, keyName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
token, err := getTokenIfExists(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var payload appJWTPayload
|
||||
err = json.Unmarshal(token.Payload, &payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if payload.isExpired() {
|
||||
err := os.Remove(path)
|
||||
return "", err
|
||||
}
|
||||
return token.Encode(), nil
|
||||
|
||||
}
|
||||
|
||||
// GetTokenIfExists will return the token from local storage if it exists and not expired
|
||||
func getTokenIfExists(path string) (*jose.JWT, error) {
|
||||
content, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
token, err := jose.ParseJWT(string(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// RemoveTokenIfExists removes the a token from local storage if it exists
|
||||
func RemoveTokenIfExists(url *url.URL) error {
|
||||
path, err := path.GenerateAppTokenFilePathFromURL(url, keyName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,52 +0,0 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSignalHandler(t *testing.T) {
|
||||
sigHandler := signalHandler{signals: []os.Signal{syscall.SIGUSR1}}
|
||||
handlerRan := false
|
||||
done := make(chan struct{})
|
||||
timer := time.NewTimer(time.Second)
|
||||
sigHandler.register(func(){
|
||||
handlerRan = true
|
||||
done <- struct{}{}
|
||||
})
|
||||
|
||||
p, err := os.FindProcess(os.Getpid())
|
||||
require.Nil(t, err)
|
||||
p.Signal(syscall.SIGUSR1)
|
||||
|
||||
// Blocks for up to one second to make sure the handler callback runs before the assert.
|
||||
select {
|
||||
case <- done:
|
||||
assert.True(t, handlerRan)
|
||||
case <- timer.C:
|
||||
t.Fail()
|
||||
}
|
||||
sigHandler.deregister()
|
||||
}
|
||||
|
||||
func TestSignalHandlerClose(t *testing.T) {
|
||||
sigHandler := signalHandler{signals: []os.Signal{syscall.SIGUSR1}}
|
||||
done := make(chan struct{})
|
||||
timer := time.NewTimer(time.Second)
|
||||
sigHandler.register(func(){done <- struct{}{}})
|
||||
sigHandler.deregister()
|
||||
|
||||
p, err := os.FindProcess(os.Getpid())
|
||||
require.Nil(t, err)
|
||||
p.Signal(syscall.SIGUSR1)
|
||||
select {
|
||||
case <- done:
|
||||
t.Fail()
|
||||
case <- timer.C:
|
||||
}
|
||||
}
|
@@ -1,159 +0,0 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/encrypter"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/shell"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
baseStoreURL = "https://login.argotunnel.com/"
|
||||
clientTimeout = time.Second * 60
|
||||
)
|
||||
|
||||
// Run does the transfer "dance" with the end result downloading the supported resource.
|
||||
// The expanded description is run is encapsulation of shared business logic needed
|
||||
// to request a resource (token/cert/etc) from the transfer service (loginhelper).
|
||||
// The "dance" we refer to is building a HTTP request, opening that in a browser waiting for
|
||||
// the user to complete an action, while it long polls in the background waiting for an
|
||||
// action to be completed to download the resource.
|
||||
func Run(transferURL *url.URL, resourceName, key, value string, shouldEncrypt bool, useHostOnly bool, log *zerolog.Logger) ([]byte, error) {
|
||||
encrypterClient, err := encrypter.New("cloudflared_priv.pem", "cloudflared_pub.pem")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestURL, err := buildRequestURL(transferURL, key, value+encrypterClient.PublicKey(), shouldEncrypt, useHostOnly)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// See AUTH-1423 for why we use stderr (the way git wraps ssh)
|
||||
err = shell.OpenBrowser(requestURL)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Please open the following URL and log in with your Cloudflare account:\n\n%s\n\nLeave cloudflared running to download the %s automatically.\n", requestURL, resourceName)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "A browser window should have opened at the following URL:\n\n%s\n\nIf the browser failed to open, please visit the URL above directly in your browser.\n", requestURL)
|
||||
}
|
||||
|
||||
var resourceData []byte
|
||||
|
||||
if shouldEncrypt {
|
||||
buf, key, err := transferRequest(baseStoreURL+"transfer/"+encrypterClient.PublicKey(), log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decodedBuf, err := base64.StdEncoding.DecodeString(string(buf))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decrypted, err := encrypterClient.Decrypt(decodedBuf, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resourceData = decrypted
|
||||
} else {
|
||||
buf, _, err := transferRequest(baseStoreURL+encrypterClient.PublicKey(), log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resourceData = buf
|
||||
}
|
||||
|
||||
return resourceData, nil
|
||||
|
||||
}
|
||||
|
||||
// BuildRequestURL creates a request suitable for a resource transfer.
|
||||
// it will return a constructed url based off the base url and query key/value provided.
|
||||
// cli will build a url for cli transfer request.
|
||||
func buildRequestURL(baseURL *url.URL, key, value string, cli, useHostOnly bool) (string, error) {
|
||||
q := baseURL.Query()
|
||||
q.Set(key, value)
|
||||
baseURL.RawQuery = q.Encode()
|
||||
if useHostOnly {
|
||||
baseURL.Path = ""
|
||||
}
|
||||
if !cli {
|
||||
return baseURL.String(), nil
|
||||
}
|
||||
q.Set("redirect_url", baseURL.String()) // we add the token as a query param on both the redirect_url and the main url
|
||||
q.Set("send_org_token", "true") // indicates that the cli endpoint should return both the org and app token
|
||||
baseURL.RawQuery = q.Encode() // and this actual baseURL.
|
||||
baseURL.Path = "cdn-cgi/access/cli"
|
||||
return baseURL.String(), nil
|
||||
}
|
||||
|
||||
// transferRequest downloads the requested resource from the request URL
|
||||
func transferRequest(requestURL string, log *zerolog.Logger) ([]byte, string, error) {
|
||||
client := &http.Client{Timeout: clientTimeout}
|
||||
const pollAttempts = 10
|
||||
// we do "long polling" on the endpoint to get the resource.
|
||||
for i := 0; i < pollAttempts; i++ {
|
||||
buf, key, err := poll(client, requestURL, log)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
} else if len(buf) > 0 {
|
||||
if err := putSuccess(client, requestURL); err != nil {
|
||||
log.Err(err).Msg("Failed to update resource success")
|
||||
}
|
||||
return buf, key, nil
|
||||
}
|
||||
}
|
||||
return nil, "", errors.New("Failed to fetch resource")
|
||||
}
|
||||
|
||||
// poll the endpoint for the request resource, waiting for the user interaction
|
||||
func poll(client *http.Client, requestURL string, log *zerolog.Logger) ([]byte, string, error) {
|
||||
resp, err := client.Get(requestURL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// ignore everything other than server errors as the resource
|
||||
// may not exist until the user does the interaction
|
||||
if resp.StatusCode >= 500 {
|
||||
return nil, "", fmt.Errorf("error on request %d", resp.StatusCode)
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
log.Info().Msg("Waiting for login...")
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if _, err := io.Copy(buf, resp.Body); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return buf.Bytes(), resp.Header.Get("service-public-key"), nil
|
||||
}
|
||||
|
||||
// putSuccess tells the server we successfully downloaded the resource
|
||||
func putSuccess(client *http.Client, requestURL string) error {
|
||||
req, err := http.NewRequest("PUT", requestURL+"/ok", nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("HTTP Response Status Code: %d", resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -15,9 +15,10 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/proxydns"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
@@ -104,7 +105,7 @@ func Commands() []*cli.Command {
|
||||
buildDeleteCommand(),
|
||||
buildCleanupCommand(),
|
||||
// for compatibility, allow following as tunnel subcommands
|
||||
tunneldns.Command(true),
|
||||
proxydns.Command(true),
|
||||
cliutil.RemovedCommand("db-connect"),
|
||||
}
|
||||
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
@@ -207,7 +207,7 @@ func prepareTunnelConfig(
|
||||
ClientID: clientUUID[:],
|
||||
Features: dedup(features),
|
||||
Version: version,
|
||||
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
|
||||
Arch: buildInfo.OSArch(),
|
||||
}
|
||||
ingressRules, err = ingress.ParseIngress(cfg)
|
||||
if err != nil && err != ingress.ErrNoIngressRules {
|
||||
@@ -272,7 +272,7 @@ func prepareTunnelConfig(
|
||||
|
||||
return &origin.TunnelConfig{
|
||||
ConnectionConfig: connectionConfig,
|
||||
BuildInfo: buildInfo,
|
||||
OSArch: buildInfo.OSArch(),
|
||||
ClientID: clientID,
|
||||
EdgeAddrs: c.StringSlice("edge"),
|
||||
HAConnections: c.Int("ha-connections"),
|
||||
|
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
|
@@ -5,7 +5,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/transfer"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/token"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -56,7 +56,7 @@ func login(c *cli.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
resourceData, err := transfer.Run(
|
||||
resourceData, err := token.RunTransfer(
|
||||
loginURL,
|
||||
"cert",
|
||||
"callback",
|
||||
|
@@ -13,7 +13,6 @@ import (
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
"github.com/pkg/errors"
|
||||
@@ -23,7 +22,8 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/tunnelstore"
|
||||
)
|
||||
|
@@ -8,12 +8,13 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/facebookgo/grace/gracenet"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
|
Reference in New Issue
Block a user