TUN-528: Move cloudflared into a separate repo

This commit is contained in:
Areg Harutyunyan
2018-05-01 18:45:06 -05:00
parent e8c621a648
commit d06fc520c7
4726 changed files with 1763680 additions and 0 deletions

View File

@@ -0,0 +1,320 @@
package main
import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"fmt"
"io/ioutil"
"math/rand"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tlsconfig"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/validation"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
)
var (
defaultConfigFiles = []string{"config.yml", "config.yaml"}
// Launchd doesn't set root env variables, so there is default
// Windows default config dir was ~/cloudflare-warp in documentation; let's keep it compatible
defaultConfigDirs = []string{"~/.cloudflared", "~/.cloudflare-warp", "~/cloudflare-warp", "/usr/local/etc/cloudflared", "/etc/cloudflared"}
)
const defaultCredentialFile = "cert.pem"
func fileExists(path string) (bool, error) {
f, err := os.Open(path)
if err != nil {
if os.IsNotExist(err) {
// ignore missing files
return false, nil
}
return false, err
}
f.Close()
return true, nil
}
// returns the first path that contains a cert.pem file. If none of the defaultConfigDirs
// (differs by OS for legacy reasons) contains a cert.pem file, return empty string
func findDefaultOriginCertPath() string {
for _, defaultConfigDir := range defaultConfigDirs {
originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, defaultCredentialFile))
if ok, _ := fileExists(originCertPath); ok {
return originCertPath
}
}
return ""
}
// returns the first path that contains a config file. If none of the combination of
// defaultConfigDirs (differs by OS for legacy reasons) and defaultConfigFiles
// contains a config file, return empty string
func findDefaultConfigPath() string {
for _, configDir := range defaultConfigDirs {
for _, configFile := range defaultConfigFiles {
dirPath, err := homedir.Expand(configDir)
if err != nil {
continue
}
path := filepath.Join(dirPath, configFile)
if ok, _ := fileExists(path); ok {
return path
}
}
}
return ""
}
func findInputSourceContext(context *cli.Context) (altsrc.InputSourceContext, error) {
if context.String("config") != "" {
return altsrc.NewYamlSourceFromFile(context.String("config"))
}
return nil, nil
}
func generateRandomClientID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
id := make([]byte, 32)
r.Read(id)
return hex.EncodeToString(id)
}
func enoughOptionsSet(c *cli.Context) bool {
// For cloudflared to work, the user needs to at least provide a hostname,
// or runs as stand alone DNS proxy .
// When using sudo, use -E flag to preserve env vars
if c.NumFlags() == 0 && c.NArg() == 0 && os.Getenv("TUNNEL_HOSTNAME") == "" && os.Getenv("TUNNEL_DNS") == "" {
if isRunningFromTerminal() {
logger.Errorf("No arguments were provided. You need to at least specify the hostname for this tunnel. See %s", quickStartUrl)
logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, run with --proxy-dns option or set TUNNEL_DNS environment variable.")
} else {
logger.Errorf("You need to specify all the options in a configuration file, or use environment variables. See %s and %s", serviceUrl, argumentsUrl)
logger.Infof("If you want to run Argo Tunnel client as a stand alone DNS proxy, specify proxy-dns option in the configuration file, or set TUNNEL_DNS environment variable.")
}
cli.ShowAppHelp(c)
return false
}
return true
}
func handleDeprecatedOptions(c *cli.Context) error {
// Fail if the user provided an old authentication method
if c.IsSet("api-key") || c.IsSet("api-email") || c.IsSet("api-ca-key") {
logger.Error("You don't need to give us your api-key anymore. Please use the new login method. Just run cloudflared login")
return fmt.Errorf("Client provided deprecated options")
}
return nil
}
// validate url. It can be either from --url or argument
func validateUrl(c *cli.Context) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}
url = c.Args().Get(0)
}
validUrl, err := validation.ValidateUrl(url)
return validUrl, err
}
func logClientOptions(c *cli.Context) {
flags := make(map[string]interface{})
for _, flag := range c.LocalFlagNames() {
flags[flag] = c.Generic(flag)
}
if len(flags) > 0 {
logger.Infof("Flags %v", flags)
}
envs := make(map[string]string)
// Find env variables for Argo Tunnel
for _, env := range os.Environ() {
// All Argo Tunnel env variables start with TUNNEL_
if strings.Contains(env, "TUNNEL_") {
vars := strings.Split(env, "=")
if len(vars) == 2 {
envs[vars[0]] = vars[1]
}
}
}
if len(envs) > 0 {
logger.Infof("Environmental variables %v", envs)
}
}
func dnsProxyStandAlone(c *cli.Context) bool {
return c.IsSet("proxy-dns") && (!c.IsSet("hostname") && !c.IsSet("tag") && !c.IsSet("hello-world"))
}
func getOriginCert(c *cli.Context) ([]byte, error) {
if c.String("origincert") == "" {
logger.Warnf("Cannot determine default origin certificate path. No file %s in %v", defaultCredentialFile, defaultConfigDirs)
if isRunningFromTerminal() {
logger.Errorf("You need to specify the origin certificate path with --origincert option, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", argumentsUrl)
return nil, fmt.Errorf("Client didn't specify origincert path when running from terminal")
} else {
logger.Errorf("You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable. See %s for more information.", serviceUrl)
return nil, fmt.Errorf("Client didn't specify origincert path")
}
}
// Check that the user has acquired a certificate using the login command
originCertPath, err := homedir.Expand(c.String("origincert"))
if err != nil {
logger.WithError(err).Errorf("Cannot resolve path %s", c.String("origincert"))
return nil, fmt.Errorf("Cannot resolve path %s", c.String("origincert"))
}
ok, err := fileExists(originCertPath)
if err != nil {
logger.Errorf("Cannot check if origin cert exists at path %s", c.String("origincert"))
return nil, fmt.Errorf("Cannot check if origin cert exists at path %s", c.String("origincert"))
}
if !ok {
logger.Errorf(`Cannot find a valid certificate for your origin at the path:
%s
If the path above is wrong, specify the path with the -origincert option.
If you don't have a certificate signed by Cloudflare, run the command:
%s login
`, originCertPath, os.Args[0])
return nil, fmt.Errorf("Cannot find a valid certificate at the path %s", originCertPath)
}
// Easier to send the certificate as []byte via RPC than decoding it at this point
originCert, err := ioutil.ReadFile(originCertPath)
if err != nil {
logger.WithError(err).Errorf("Cannot read %s to load origin certificate", originCertPath)
return nil, fmt.Errorf("Cannot read %s to load origin certificate", originCertPath)
}
return originCert, nil
}
func prepareTunnelConfig(c *cli.Context, buildInfo *origin.BuildInfo, logger, protoLogger *logrus.Logger) (*origin.TunnelConfig, error) {
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
logger.WithError(err).Error("Invalid hostname")
return nil, errors.Wrap(err, "Invalid hostname")
}
clientID := c.String("id")
if !c.IsSet("id") {
clientID = generateRandomClientID()
}
tags, err := NewTagSliceFromCLI(c.StringSlice("tag"))
if err != nil {
logger.WithError(err).Error("Tag parse failure")
return nil, errors.Wrap(err, "Tag parse failure")
}
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
url, err := validateUrl(c)
if err != nil {
logger.WithError(err).Error("Error validating url")
return nil, errors.Wrap(err, "Error validating url")
}
logger.Infof("Proxying tunnel requests to %s", url)
originCert, err := getOriginCert(c)
if err != nil {
return nil, errors.Wrap(err, "Error getting origin cert")
}
originCertPool, err := loadCertPool(c, logger)
if err != nil {
logger.WithError(err).Error("Error loading cert pool")
return nil, errors.Wrap(err, "Error loading cert pool")
}
tunnelMetrics := origin.NewTunnelMetrics()
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: c.Duration("proxy-connect-timeout"),
KeepAlive: c.Duration("proxy-tcp-keepalive"),
DualStack: !c.Bool("proxy-no-happy-eyeballs"),
}).DialContext,
MaxIdleConns: c.Int("proxy-keepalive-connections"),
IdleConnTimeout: c.Duration("proxy-keepalive-timeout"),
TLSHandshakeTimeout: c.Duration("proxy-tls-timeout"),
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: c.IsSet("no-tls-verify")},
}
if !c.IsSet("hello-world") && c.IsSet("origin-server-name") {
httpTransport.TLSClientConfig.ServerName = c.String("origin-server-name")
}
return &origin.TunnelConfig{
EdgeAddrs: c.StringSlice("edge"),
OriginUrl: url,
Hostname: hostname,
OriginCert: originCert,
TlsConfig: tlsconfig.CreateTunnelConfig(c, c.StringSlice("edge")),
ClientTlsConfig: httpTransport.TLSClientConfig,
Retries: c.Uint("retries"),
HeartbeatInterval: c.Duration("heartbeat-interval"),
MaxHeartbeats: c.Uint64("heartbeat-count"),
ClientID: clientID,
BuildInfo: buildInfo,
ReportedVersion: Version,
LBPool: c.String("lb-pool"),
Tags: tags,
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
ProtocolLogger: protoLogger,
Logger: logger,
IsAutoupdated: c.Bool("is-autoupdated"),
GracePeriod: c.Duration("grace-period"),
RunFromTerminal: isRunningFromTerminal(),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
CompressionQuality: c.Uint64("compression-quality"),
}, nil
}
func loadCertPool(c *cli.Context, logger *logrus.Logger) (*x509.CertPool, error) {
const originCAPoolFlag = "origin-ca-pool"
originCAPoolFilename := c.String(originCAPoolFlag)
var originCustomCAPool []byte
if originCAPoolFilename != "" {
var err error
originCustomCAPool, err = ioutil.ReadFile(originCAPoolFilename)
if err != nil {
return nil, errors.Wrap(err, fmt.Sprintf("unable to read the file %s for --%s", originCAPoolFilename, originCAPoolFlag))
}
}
originCertPool, err := tlsconfig.LoadOriginCertPool(originCustomCAPool)
if err != nil {
return nil, errors.Wrap(err, "error loading the certificate pool")
}
// Windows users should be notified that they can use the flag
if runtime.GOOS == "windows" && originCAPoolFilename == "" {
logger.Infof("cloudflared does not support loading the system root certificate pool on Windows. Please use the --%s to specify it", originCAPoolFlag)
}
return originCertPool, nil
}

View File

@@ -0,0 +1,13 @@
// +build !windows,!darwin,!linux
package main
import (
"os"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Run(os.Args)
}

21
cmd/cloudflared/hello.go Normal file
View File

@@ -0,0 +1,21 @@
package main
import (
"fmt"
"gopkg.in/urfave/cli.v2"
"github.com/cloudflare/cloudflared/hello"
)
func helloWorld(c *cli.Context) error {
address := fmt.Sprintf(":%d", c.Int("port"))
listener, err := hello.CreateTLSListener(address)
if err != nil {
return err
}
defer listener.Close()
err = hello.StartHelloWorldServer(logger, listener, nil)
return err
}

View File

@@ -0,0 +1,35 @@
package main
import (
"testing"
)
func TestCreateListenerHostAndPortSuccess(t *testing.T) {
listener, err := createListener("localhost:1234")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyHostSuccess(t *testing.T) {
listener, err := createListener("localhost:")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}
func TestCreateListenerOnlyPortSuccess(t *testing.T) {
listener, err := createListener(":8888")
if err != nil {
t.Fatal(err)
}
if listener.Addr().String() == "" {
t.Fatal("Fail to find available port")
}
}

View File

@@ -0,0 +1,292 @@
// +build linux
package main
import (
"fmt"
"os"
"path/filepath"
cli "gopkg.in/urfave/cli.v2"
)
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel system service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Argo Tunnel as a system service",
Action: installLinuxService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Action: uninstallLinuxService,
},
},
})
app.Run(os.Args)
}
const serviceConfigDir = "/etc/cloudflared"
var systemdTemplates = []ServiceTemplate{
{
Path: "/etc/systemd/system/cloudflared.service",
Content: `[Unit]
Description=Argo Tunnel
After=network.target
[Service]
TimeoutStartSec=0
Type=notify
ExecStart={{ .Path }} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --no-autoupdate
Restart=on-failure
RestartSec=5s
[Install]
WantedBy=multi-user.target
`,
},
{
Path: "/etc/systemd/system/cloudflared-update.service",
Content: `[Unit]
Description=Update Argo Tunnel
After=network.target
[Service]
ExecStart=/bin/bash -c '{{ .Path }} update; code=$?; if [ $code -eq 64 ]; then systemctl restart cloudflared; exit 0; fi; exit $code'
`,
},
{
Path: "/etc/systemd/system/cloudflared-update.timer",
Content: `[Unit]
Description=Update Argo Tunnel
[Timer]
OnUnitActiveSec=1d
[Install]
WantedBy=timers.target
`,
},
}
var sysvTemplate = ServiceTemplate{
Path: "/etc/init.d/cloudflared",
FileMode: 0755,
Content: `# For RedHat and cousins:
# chkconfig: 2345 99 01
# description: Argo Tunnel agent
# processname: {{.Path}}
### BEGIN INIT INFO
# Provides: {{.Path}}
# Required-Start:
# Required-Stop:
# Default-Start: 2 3 4 5
# Default-Stop: 0 1 6
# Short-Description: Argo Tunnel
# Description: Argo Tunnel agent
### END INIT INFO
name=$(basename $(readlink -f $0))
cmd="{{.Path}} --config /etc/cloudflared/config.yml --origincert /etc/cloudflared/cert.pem --pidfile /var/run/$name.pid --autoupdate-freq 24h0m0s"
pid_file="/var/run/$name.pid"
stdout_log="/var/log/$name.log"
stderr_log="/var/log/$name.err"
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
get_pid() {
cat "$pid_file"
}
is_running() {
[ -f "$pid_file" ] && ps $(get_pid) > /dev/null 2>&1
}
case "$1" in
start)
if is_running; then
echo "Already started"
else
echo "Starting $name"
$cmd >> "$stdout_log" 2>> "$stderr_log" &
echo $! > "$pid_file"
if ! is_running; then
echo "Unable to start, see $stdout_log and $stderr_log"
exit 1
fi
fi
;;
stop)
if is_running; then
echo -n "Stopping $name.."
kill $(get_pid)
for i in {1..10}
do
if ! is_running; then
break
fi
echo -n "."
sleep 1
done
echo
if is_running; then
echo "Not stopped; may still be shutting down or shutdown may have failed"
exit 1
else
echo "Stopped"
if [ -f "$pid_file" ]; then
rm "$pid_file"
fi
fi
else
echo "Not running"
fi
;;
restart)
$0 stop
if is_running; then
echo "Unable to stop, will not attempt to start"
exit 1
fi
$0 start
;;
status)
if is_running; then
echo "Running"
else
echo "Stopped"
exit 1
fi
;;
*)
echo "Usage: $0 {start|stop|restart|status}"
exit 1
;;
esac
exit 0
`,
}
func isSystemd() bool {
if _, err := os.Stat("/run/systemd/system"); err == nil {
return true
}
return false
}
func installLinuxService(c *cli.Context) error {
etPath, err := os.Executable()
if err != nil {
return fmt.Errorf("error determining executable path: %v", err)
}
templateArgs := ServiceTemplateArgs{Path: etPath}
defaultConfigDir := filepath.Dir(c.String("config"))
defaultConfigFile := filepath.Base(c.String("config"))
if err = copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile); err != nil {
logger.WithError(err).Infof("Failed to copy user configuration. Before running the service, ensure that %s contains two files, %s and %s",
serviceConfigDir, defaultCredentialFile, defaultConfigFiles[0])
return err
}
switch {
case isSystemd():
logger.Infof("Using Systemd")
return installSystemd(&templateArgs)
default:
logger.Infof("Using Sysv")
return installSysv(&templateArgs)
}
}
func installSystemd(templateArgs *ServiceTemplateArgs) error {
for _, serviceTemplate := range systemdTemplates {
err := serviceTemplate.Generate(templateArgs)
if err != nil {
logger.WithError(err).Infof("error generating service template")
return err
}
}
if err := runCommand("systemctl", "enable", "cloudflared.service"); err != nil {
logger.WithError(err).Infof("systemctl enable cloudflared.service error")
return err
}
if err := runCommand("systemctl", "start", "cloudflared-update.timer"); err != nil {
logger.WithError(err).Infof("systemctl start cloudflared-update.timer error")
return err
}
logger.Infof("systemctl daemon-reload")
return runCommand("systemctl", "daemon-reload")
}
func installSysv(templateArgs *ServiceTemplateArgs) error {
confPath, err := sysvTemplate.ResolvePath()
if err != nil {
logger.WithError(err).Infof("error resolving system path")
return err
}
if err := sysvTemplate.Generate(templateArgs); err != nil {
logger.WithError(err).Infof("error generating system template")
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Symlink(confPath, "/etc/rc"+i+".d/K02et"); err != nil {
continue
}
}
return nil
}
func uninstallLinuxService(c *cli.Context) error {
switch {
case isSystemd():
logger.Infof("Using Systemd")
return uninstallSystemd()
default:
logger.Infof("Using Sysv")
return uninstallSysv()
}
}
func uninstallSystemd() error {
if err := runCommand("systemctl", "disable", "cloudflared.service"); err != nil {
logger.WithError(err).Infof("systemctl disable cloudflared.service error")
return err
}
if err := runCommand("systemctl", "stop", "cloudflared-update.timer"); err != nil {
logger.WithError(err).Infof("systemctl stop cloudflared-update.timer error")
return err
}
for _, serviceTemplate := range systemdTemplates {
if err := serviceTemplate.Remove(); err != nil {
logger.WithError(err).Infof("error removing service template")
return err
}
}
logger.Infof("Successfully uninstall cloudflared service")
return nil
}
func uninstallSysv() error {
if err := sysvTemplate.Remove(); err != nil {
logger.WithError(err).Infof("error removing service template")
return err
}
for _, i := range [...]string{"2", "3", "4", "5"} {
if err := os.Remove("/etc/rc" + i + ".d/S50et"); err != nil {
continue
}
}
for _, i := range [...]string{"0", "1", "6"} {
if err := os.Remove("/etc/rc" + i + ".d/K02et"); err != nil {
continue
}
}
logger.Infof("Successfully uninstall cloudflared service")
return nil
}

68
cmd/cloudflared/logger.go Normal file
View File

@@ -0,0 +1,68 @@
package main
import (
"fmt"
"os"
"github.com/cloudflare/cloudflared/log"
"github.com/rifflock/lfshook"
"github.com/sirupsen/logrus"
"gopkg.in/urfave/cli.v2"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
)
var logger = log.CreateLogger()
func configMainLogger(c *cli.Context) error {
logLevel, err := logrus.ParseLevel(c.String("loglevel"))
if err != nil {
logger.WithError(err).Error("Unknown logging level specified")
return errors.Wrap(err, "Unknown logging level specified")
}
logger.SetLevel(logLevel)
return nil
}
func configProtoLogger(c *cli.Context) (*logrus.Logger, error) {
protoLogLevel, err := logrus.ParseLevel(c.String("proto-loglevel"))
if err != nil {
logger.WithError(err).Fatal("Unknown protocol logging level specified")
return nil, errors.Wrap(err, "Unknown protocol logging level specified")
}
protoLogger := logrus.New()
protoLogger.Level = protoLogLevel
return protoLogger, nil
}
func initLogFile(c *cli.Context, loggers ...*logrus.Logger) error {
filePath, err := homedir.Expand(c.String("logfile"))
if err != nil {
return errors.Wrap(err, "Cannot resolve logfile path")
}
fileMode := os.O_WRONLY | os.O_APPEND | os.O_CREATE | os.O_TRUNC
// do not truncate log file if the client has been autoupdated
if c.Bool("is-autoupdated") {
fileMode = os.O_WRONLY | os.O_APPEND | os.O_CREATE
}
f, err := os.OpenFile(filePath, fileMode, 0664)
if err != nil {
errors.Wrap(err, fmt.Sprintf("Cannot open file %s", filePath))
}
defer f.Close()
pathMap := lfshook.PathMap{
logrus.InfoLevel: filePath,
logrus.ErrorLevel: filePath,
logrus.FatalLevel: filePath,
logrus.PanicLevel: filePath,
}
for _, l := range loggers {
l.Hooks.Add(lfshook.NewHook(pathMap, &logrus.JSONFormatter{}))
}
return nil
}

194
cmd/cloudflared/login.go Normal file
View File

@@ -0,0 +1,194 @@
package main
import (
"crypto/rand"
"encoding/base32"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"syscall"
"time"
homedir "github.com/mitchellh/go-homedir"
cli "gopkg.in/urfave/cli.v2"
)
const baseLoginURL = "https://dash.cloudflare.com/warp"
const baseCertStoreURL = "https://login.cloudflarewarp.com"
const clientTimeout = time.Minute * 20
func login(c *cli.Context) error {
configPath, err := homedir.Expand(defaultConfigDirs[0])
if err != nil {
return err
}
ok, err := fileExists(configPath)
if !ok && err == nil {
// create config directory if doesn't already exist
err = os.Mkdir(configPath, 0700)
}
if err != nil {
return err
}
path := filepath.Join(configPath, defaultCredentialFile)
fileInfo, err := os.Stat(path)
if err == nil && fileInfo.Size() > 0 {
fmt.Fprintf(os.Stderr, `You have an existing certificate at %s which login would overwrite.
If this is intentional, please move or delete that file then run this command again.
`, path)
return nil
}
if err != nil && err.(*os.PathError).Err != syscall.ENOENT {
return err
}
// for local debugging
baseURL := baseCertStoreURL
if c.IsSet("url") {
baseURL = c.String("url")
}
// Generate a random post URL
certURL := baseURL + generateRandomPath()
loginURL, err := url.Parse(baseLoginURL)
if err != nil {
// shouldn't happen, URL is hardcoded
return err
}
loginURL.RawQuery = "callback=" + url.QueryEscape(certURL)
err = open(loginURL.String())
if err != nil {
fmt.Fprintf(os.Stderr, `Please open the following URL and log in with your Cloudflare account:
%s
Leave cloudflared running to install the certificate automatically.
`, loginURL.String())
} else {
fmt.Fprintf(os.Stderr, `A browser window should have opened at the following URL:
%s
If the browser failed to open, open it yourself and visit the URL above.
`, loginURL.String())
}
if download(certURL, path) {
fmt.Fprintf(os.Stderr, `You have successfully logged in.
If you wish to copy your credentials to a server, they have been saved to:
%s
`, path)
} else {
fmt.Fprintf(os.Stderr, `Failed to write the certificate due to the following error:
%v
Your browser will download the certificate instead. You will have to manually
copy it to the following path:
%s
`, err, path)
}
return nil
}
// generateRandomPath generates a random URL to associate with the certificate.
func generateRandomPath() string {
randomBytes := make([]byte, 40)
_, err := rand.Read(randomBytes)
if err != nil {
panic(err)
}
return "/" + base32.StdEncoding.EncodeToString(randomBytes)
}
// open opens the specified URL in the default browser of the user.
func open(url string) error {
var cmd string
var args []string
switch runtime.GOOS {
case "windows":
cmd = "cmd"
args = []string{"/c", "start"}
case "darwin":
cmd = "open"
default: // "linux", "freebsd", "openbsd", "netbsd"
cmd = "xdg-open"
}
args = append(args, url)
return exec.Command(cmd, args...).Start()
}
func download(certURL, filePath string) bool {
client := &http.Client{Timeout: clientTimeout}
// attempt a (long-running) certificate get
for i := 0; i < 20; i++ {
ok, err := tryDownload(client, certURL, filePath)
if ok {
putSuccess(client, certURL)
return true
}
if err != nil {
logger.WithError(err).Error("Error fetching certificate")
return false
}
}
return false
}
func tryDownload(client *http.Client, certURL, filePath string) (ok bool, err error) {
resp, err := client.Get(certURL)
if err != nil {
return false, err
}
defer resp.Body.Close()
if resp.StatusCode == 404 {
return false, nil
}
if resp.StatusCode != 200 {
return false, fmt.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
}
if resp.Header.Get("Content-Type") != "application/x-pem-file" {
return false, fmt.Errorf("Unexpected content type %s", resp.Header.Get("Content-Type"))
}
// write response
file, err := os.Create(filePath)
if err != nil {
return false, err
}
defer file.Close()
written, err := io.Copy(file, resp.Body)
switch {
case err != nil:
return false, err
case resp.ContentLength != written && resp.ContentLength != -1:
return false, fmt.Errorf("Short read (%d bytes) from server while writing certificate", written)
default:
return true, nil
}
}
func putSuccess(client *http.Client, certURL string) {
// indicate success to the relay server
req, err := http.NewRequest("PUT", certURL+"/ok", nil)
if err != nil {
logger.WithError(err).Error("HTTP request error")
return
}
resp, err := client.Do(req)
if err != nil {
logger.WithError(err).Error("HTTP error")
return
}
resp.Body.Close()
if resp.StatusCode != 200 {
logger.Errorf("Unexpected HTTP error code %d", resp.StatusCode)
}
}

View File

@@ -0,0 +1,190 @@
// +build darwin
package main
import (
"fmt"
"os"
"gopkg.in/urfave/cli.v2"
"github.com/pkg/errors"
)
const (
launchdIdentifier = "com.cloudflare.cloudflared"
)
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel launch agent",
Subcommands: []*cli.Command{
{
Name: "install",
Usage: "Install Argo Tunnel as an user launch agent",
Action: installLaunchd,
},
{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel launch agent",
Action: uninstallLaunchd,
},
},
})
app.Run(os.Args)
}
func newLaunchdTemplate(installPath, stdoutPath, stderrPath string) *ServiceTemplate {
return &ServiceTemplate{
Path: installPath,
Content: fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>%s</string>
<key>ProgramArguments</key>
<array>
<string>{{ .Path }}</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>%s</string>
<key>StandardErrorPath</key>
<string>%s</string>
<key>KeepAlive</key>
<dict>
<key>SuccessfulExit</key>
<false/>
</dict>
<key>ThrottleInterval</key>
<integer>20</integer>
</dict>
</plist>`, launchdIdentifier, stdoutPath, stderrPath),
}
}
func isRootUser() bool {
return os.Geteuid() == 0
}
func installPath() (string, error) {
// User is root, use /Library/LaunchDaemons instead of home directory
if isRootUser() {
return fmt.Sprintf("/Library/LaunchDaemons/%s.plist", launchdIdentifier), nil
}
userHomeDir, err := userHomeDir()
if err != nil {
return "", err
}
return fmt.Sprintf("%s/Library/LaunchAgents/%s.plist", userHomeDir, launchdIdentifier), nil
}
func stdoutPath() (string, error) {
if isRootUser() {
return fmt.Sprintf("/Library/Logs/%s.out.log", launchdIdentifier), nil
}
userHomeDir, err := userHomeDir()
if err != nil {
return "", err
}
return fmt.Sprintf("%s/Library/Logs/%s.out.log", userHomeDir, launchdIdentifier), nil
}
func stderrPath() (string, error) {
if isRootUser() {
return fmt.Sprintf("/Library/Logs/%s.err.log", launchdIdentifier), nil
}
userHomeDir, err := userHomeDir()
if err != nil {
return "", err
}
return fmt.Sprintf("%s/Library/Logs/%s.err.log", userHomeDir, launchdIdentifier), nil
}
func installLaunchd(c *cli.Context) error {
if isRootUser() {
logger.Infof("Installing Argo Tunnel client as a system launch daemon. " +
"Argo Tunnel client will run at boot")
} else {
logger.Infof("Installing Argo Tunnel client as an user launch agent. " +
"Note that Argo Tunnel client will only run when the user is logged in. " +
"If you want to run Argo Tunnel client at boot, install with root permission. " +
"For more information, visit https://developers.cloudflare.com/argo-tunnel/reference/service/")
}
etPath, err := os.Executable()
if err != nil {
logger.WithError(err).Errorf("Error determining executable path")
return fmt.Errorf("Error determining executable path: %v", err)
}
installPath, err := installPath()
if err != nil {
return errors.Wrap(err, "Error determining install path")
}
stdoutPath, err := stdoutPath()
if err != nil {
return errors.Wrap(err, "error determining stdout path")
}
stderrPath, err := stderrPath()
if err != nil {
return errors.Wrap(err, "error determining stderr path")
}
launchdTemplate := newLaunchdTemplate(installPath, stdoutPath, stderrPath)
if err != nil {
logger.WithError(err).Errorf("error creating launchd template")
return errors.Wrap(err, "error creating launchd template")
}
templateArgs := ServiceTemplateArgs{Path: etPath}
err = launchdTemplate.Generate(&templateArgs)
if err != nil {
logger.WithError(err).Errorf("error generating launchd template")
return err
}
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
logger.WithError(err).Infof("error resolving launchd template path")
return err
}
logger.Infof("Outputs are logged to %s and %s", stderrPath, stdoutPath)
return runCommand("launchctl", "load", plistPath)
}
func uninstallLaunchd(c *cli.Context) error {
if isRootUser() {
logger.Infof("Uninstalling Argo Tunnel as a system launch daemon")
} else {
logger.Infof("Uninstalling Argo Tunnel as an user launch agent")
}
installPath, err := installPath()
if err != nil {
return errors.Wrap(err, "error determining install path")
}
stdoutPath, err := stdoutPath()
if err != nil {
return errors.Wrap(err, "error determining stdout path")
}
stderrPath, err := stderrPath()
if err != nil {
return errors.Wrap(err, "error determining stderr path")
}
launchdTemplate := newLaunchdTemplate(installPath, stdoutPath, stderrPath)
if err != nil {
return errors.Wrap(err, "error creating launchd template")
}
plistPath, err := launchdTemplate.ResolvePath()
if err != nil {
logger.WithError(err).Infof("error resolving launchd template path")
return err
}
err = runCommand("launchctl", "unload", plistPath)
if err != nil {
logger.WithError(err).Infof("error unloading")
return err
}
logger.Infof("Outputs are logged to %s and %s", stderrPath, stdoutPath)
return launchdTemplate.Remove()
}

564
cmd/cloudflared/main.go Normal file
View File

@@ -0,0 +1,564 @@
package main
import (
"fmt"
"os"
"sync"
"time"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/origin"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/getsentry/raven-go"
"github.com/mitchellh/go-homedir"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"github.com/coreos/go-systemd/daemon"
"github.com/facebookgo/grace/gracenet"
"github.com/pkg/errors"
)
const (
sentryDSN = "https://56a9c9fa5c364ab28f34b14f35ea0f1b:3e8827f6f9f740738eb11138f7bebb68@sentry.io/189878"
developerPortal = "https://developers.cloudflare.com/argo-tunnel"
quickStartUrl = developerPortal + "/quickstart/quickstart/"
serviceUrl = developerPortal + "/reference/service/"
argumentsUrl = developerPortal + "/reference/arguments/"
licenseUrl = developerPortal + "/licence/"
)
var (
Version = "DEV"
BuildTime = "unknown"
)
func main() {
metrics.RegisterBuildInfo(BuildTime, Version)
raven.SetDSN(sentryDSN)
raven.SetRelease(Version)
// Force shutdown channel used by the app. When closed, app must terminate.
// Windows service manager closes this channel when it receives shutdown command.
shutdownC := make(chan struct{})
// Graceful shutdown channel used by the app. When closed, app must terminate.
// Windows service manager closes this channel when it receives stop command.
graceShutdownC := make(chan struct{})
app := &cli.App{}
app.Name = "cloudflared"
app.Copyright = fmt.Sprintf(`(c) %d Cloudflare Inc.
Use is subject to the license agreement at %s`, time.Now().Year(), licenseUrl)
app.Usage = "Cloudflare reverse tunnelling proxy agent"
app.ArgsUsage = "origin-url"
app.Version = fmt.Sprintf("%s (built %s)", Version, BuildTime)
app.Description = `A reverse tunnel proxy agent that connects to Cloudflare's infrastructure.
Upon connecting, you are assigned a unique subdomain on cftunnel.com.
You need to specify a hostname on a zone you control.
A DNS record will be created to CNAME your hostname to the unique subdomain on cftunnel.com.
Requests made to Cloudflare's servers for your hostname will be proxied
through the tunnel to your local webserver.`
app.Flags = []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Specifies a config file in YAML format.",
Value: findDefaultConfigPath(),
},
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "autoupdate-freq",
Usage: "Autoupdate frequency. Default is 24h.",
Value: time.Hour * 24,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-autoupdate",
Usage: "Disable periodic check for updates, restarting the server with the new version.",
Value: false,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "is-autoupdated",
Usage: "Signal the new process that Argo Tunnel client has been autoupdated",
Value: false,
Hidden: true,
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "edge",
Usage: "Address of the Cloudflare tunnel server.",
EnvVars: []string{"TUNNEL_EDGE"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "cacert",
Usage: "Certificate Authority authenticating the Cloudflare tunnel connection.",
EnvVars: []string{"TUNNEL_CACERT"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-tls-verify",
Usage: "Disables TLS verification of the certificate presented by your origin. Will allow any certificate from the origin to be accepted. Note: The connection from your machine to Cloudflare's Edge is still encrypted.",
EnvVars: []string{"NO_TLS_VERIFY"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origincert",
Usage: "Path to the certificate generated for your origin when you run cloudflared login.",
EnvVars: []string{"TUNNEL_ORIGIN_CERT"},
Value: findDefaultOriginCertPath(),
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-ca-pool",
Usage: "Path to the CA for the certificate of your origin. This option should be used only if your certificate is not signed by Cloudflare.",
EnvVars: []string{"TUNNEL_ORIGIN_CA_POOL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "url",
Value: "https://localhost:8080",
Usage: "Connect to the local webserver at `URL`.",
EnvVars: []string{"TUNNEL_URL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "hostname",
Usage: "Set a hostname on a Cloudflare zone to route traffic through this tunnel.",
EnvVars: []string{"TUNNEL_HOSTNAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "origin-server-name",
Usage: "Hostname on the origin server certificate.",
EnvVars: []string{"TUNNEL_ORIGIN_SERVER_NAME"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "id",
Usage: "A unique identifier used to tie connections to this tunnel instance.",
EnvVars: []string{"TUNNEL_ID"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "lb-pool",
Usage: "The name of a (new/existing) load balancing pool to add this origin to.",
EnvVars: []string{"TUNNEL_LB_POOL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-key",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_KEY"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-email",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_EMAIL"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "api-ca-key",
Usage: "This parameter has been deprecated since version 2017.10.1.",
EnvVars: []string{"TUNNEL_API_CA_KEY"},
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
EnvVars: []string{"TUNNEL_METRICS"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "metrics-update-freq",
Usage: "Frequency to update tunnel metrics",
Value: time.Second * 5,
EnvVars: []string{"TUNNEL_METRICS_UPDATE_FREQ"},
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "tag",
Usage: "Custom tags used to identify this tunnel, in format `KEY=VALUE`. Multiple tags may be specified",
EnvVars: []string{"TUNNEL_TAG"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "heartbeat-interval",
Usage: "Minimum idle time before sending a heartbeat.",
Value: time.Second * 5,
Hidden: true,
}),
altsrc.NewUint64Flag(&cli.Uint64Flag{
Name: "heartbeat-count",
Usage: "Minimum number of unacked heartbeats to send before closing the connection.",
Value: 5,
Hidden: true,
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "loglevel",
Value: "info",
Usage: "Application logging level {panic, fatal, error, warn, info, debug}",
EnvVars: []string{"TUNNEL_LOGLEVEL"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "proto-loglevel",
Value: "warn",
Usage: "Protocol logging level {panic, fatal, error, warn, info, debug}",
EnvVars: []string{"TUNNEL_PROTO_LOGLEVEL"},
}),
altsrc.NewUintFlag(&cli.UintFlag{
Name: "retries",
Value: 5,
Usage: "Maximum number of retries for connection/protocol errors.",
EnvVars: []string{"TUNNEL_RETRIES"},
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "hello-world",
Value: false,
Usage: "Run Hello World Server",
EnvVars: []string{"TUNNEL_HELLO_WORLD"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "pidfile",
Usage: "Write the application's PID to this file after first successful connection.",
EnvVars: []string{"TUNNEL_PIDFILE"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "logfile",
Usage: "Save application log to this file for reporting issues.",
EnvVars: []string{"TUNNEL_LOGFILE"},
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "ha-connections",
Value: 4,
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-connect-timeout",
Usage: "HTTP proxy timeout for establishing a new connection",
Value: time.Second * 30,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tls-timeout",
Usage: "HTTP proxy timeout for completing a TLS handshake",
Value: time.Second * 10,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-tcp-keepalive",
Usage: "HTTP proxy TCP keepalive duration",
Value: time.Second * 30,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-no-happy-eyeballs",
Usage: "HTTP proxy should disable \"happy eyeballs\" for IPv4/v6 fallback",
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "proxy-keepalive-connections",
Usage: "HTTP proxy maximum keepalive connection pool size",
Value: 100,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "proxy-keepalive-timeout",
Usage: "HTTP proxy timeout for closing an idle connection",
Value: time.Second * 90,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "proxy-dns",
Usage: "Run a DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS"},
}),
altsrc.NewIntFlag(&cli.IntFlag{
Name: "proxy-dns-port",
Value: 53,
Usage: "Listen on given port for the DNS over HTTPS proxy server.",
EnvVars: []string{"TUNNEL_DNS_PORT"},
}),
altsrc.NewStringFlag(&cli.StringFlag{
Name: "proxy-dns-address",
Usage: "Listen address for the DNS over HTTPS proxy server.",
Value: "localhost",
EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
}),
altsrc.NewStringSliceFlag(&cli.StringSliceFlag{
Name: "proxy-dns-upstream",
Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
Value: cli.NewStringSlice("https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"),
EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "grace-period",
Usage: "Duration to accept new requests after cloudflared receives first SIGINT/SIGTERM. A second SIGINT/SIGTERM will force cloudflared to shutdown immediately.",
Value: time.Second * 30,
EnvVars: []string{"TUNNEL_GRACE_PERIOD"},
Hidden: true,
}),
altsrc.NewUintFlag(&cli.UintFlag{
Name: "compression-quality",
Value: 0,
Usage: "Use cross-stream compression instead HTTP compression. 0-off, 1-low, 2-medium, >=3-high",
EnvVars: []string{"TUNNEL_COMPRESSION_LEVEL"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "no-chunked-encoding",
Usage: "Disables chunked transfer encoding; useful if you are running a WSGI server.",
EnvVars: []string{"TUNNEL_NO_CHUNKED_ENCODING"},
}),
}
app.Action = func(c *cli.Context) (err error) {
tags := make(map[string]string)
tags["hostname"] = c.String("hostname")
raven.SetTagsContext(tags)
raven.CapturePanic(func() { err = startServer(c, shutdownC, graceShutdownC) }, nil)
if err != nil {
raven.CaptureError(err, nil)
}
return err
}
app.Before = func(context *cli.Context) error {
if context.String("config") == "" {
logger.Warnf("Cannot determine default configuration path. No file %v in %v", defaultConfigFiles, defaultConfigDirs)
}
inputSource, err := findInputSourceContext(context)
if err != nil {
logger.WithError(err).Infof("Cannot load configuration from %s", context.String("config"))
return err
} else if inputSource != nil {
err := altsrc.ApplyInputSourceValues(context, inputSource, app.Flags)
if err != nil {
logger.WithError(err).Infof("Cannot apply configuration from %s", context.String("config"))
return err
}
logger.Infof("Applied configuration from %s", context.String("config"))
}
return nil
}
app.Commands = []*cli.Command{
{
Name: "update",
Action: update,
Usage: "Update the agent if a new version exists",
ArgsUsage: " ",
Description: `Looks for a new version on the offical download server.
If a new version exists, updates the agent binary and quits.
Otherwise, does nothing.
To determine if an update happened in a script, check for error code 64.`,
},
{
Name: "login",
Action: login,
Usage: "Generate a configuration file with your login details",
ArgsUsage: " ",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "url",
Hidden: true,
},
},
},
{
Name: "hello",
Action: helloWorld,
Usage: "Run a simple \"Hello World\" server for testing Argo Tunnel.",
Flags: []cli.Flag{
&cli.IntFlag{
Name: "port",
Usage: "Listen on the selected port.",
Value: 8080,
},
},
ArgsUsage: " ", // can't be the empty string or we get the default output
},
{
Name: "proxy-dns",
Action: tunneldns.Run,
Usage: "Run a DNS over HTTPS proxy server.",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "metrics",
Value: "localhost:",
Usage: "Listen address for metrics reporting.",
EnvVars: []string{"TUNNEL_METRICS"},
},
&cli.StringFlag{
Name: "address",
Usage: "Listen address for the DNS over HTTPS proxy server.",
Value: "localhost",
EnvVars: []string{"TUNNEL_DNS_ADDRESS"},
},
&cli.IntFlag{
Name: "port",
Usage: "Listen on given port for the DNS over HTTPS proxy server.",
Value: 53,
EnvVars: []string{"TUNNEL_DNS_PORT"},
},
&cli.StringSliceFlag{
Name: "upstream",
Usage: "Upstream endpoint URL, you can specify multiple endpoints for redundancy.",
Value: cli.NewStringSlice("https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"),
EnvVars: []string{"TUNNEL_DNS_UPSTREAM"},
},
},
ArgsUsage: " ", // can't be the empty string or we get the default output
},
}
runApp(app, shutdownC, graceShutdownC)
}
func startServer(c *cli.Context, shutdownC, graceShutdownC chan struct{}) error {
var wg sync.WaitGroup
listeners := gracenet.Net{}
errC := make(chan error)
connectedSignal := make(chan struct{})
dnsReadySignal := make(chan struct{})
// check whether client provides enough flags or env variables. If not, print help.
if ok := enoughOptionsSet(c); !ok {
return nil
}
if err := configMainLogger(c); err != nil {
return errors.Wrap(err, "Error configuring logger")
}
protoLogger, err := configProtoLogger(c)
if err != nil {
return errors.Wrap(err, "Error configuring protocol logger")
}
if c.String("logfile") != "" {
if err := initLogFile(c, logger, protoLogger); err != nil {
logger.Error(err)
}
}
if err := handleDeprecatedOptions(c); err != nil {
return err
}
buildInfo := origin.GetBuildInfo()
logger.Infof("Build info: %+v", *buildInfo)
logger.Infof("Version %s", Version)
logClientOptions(c)
if c.IsSet("proxy-dns") {
wg.Add(1)
go func() {
defer wg.Done()
errC <- runDNSProxyServer(c, dnsReadySignal, shutdownC)
}()
} else {
close(dnsReadySignal)
}
// Wait for proxy-dns to come up (if used)
<-dnsReadySignal
// update needs to be after DNS proxy is up to resolve equinox server address
if isAutoupdateEnabled(c) {
logger.Infof("Autoupdate frequency is set to %v", c.Duration("autoupdate-freq"))
wg.Add(1)
go func() {
defer wg.Done()
errC <- autoupdate(c.Duration("autoupdate-freq"), &listeners, shutdownC)
}()
}
metricsListener, err := listeners.Listen("tcp", c.String("metrics"))
if err != nil {
logger.WithError(err).Error("Error opening metrics server listener")
return errors.Wrap(err, "Error opening metrics server listener")
}
defer metricsListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
errC <- metrics.ServeMetrics(metricsListener, shutdownC, logger)
}()
go notifySystemd(connectedSignal)
if c.IsSet("pidfile") {
go writePidFile(connectedSignal, c.String("pidfile"))
}
// Serve DNS proxy stand-alone if no hostname or tag or app is going to run
if dnsProxyStandAlone(c) {
close(connectedSignal)
// no grace period, handle SIGINT/SIGTERM immediately
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, 0)
}
if c.IsSet("hello-world") {
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
logger.WithError(err).Error("Cannot start Hello World Server")
return errors.Wrap(err, "Cannot start Hello World Server")
}
defer helloListener.Close()
wg.Add(1)
go func() {
defer wg.Done()
hello.StartHelloWorldServer(logger, helloListener, shutdownC)
}()
c.Set("url", "https://"+helloListener.Addr().String())
}
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, logger, protoLogger)
if err != nil {
return err
}
wg.Add(1)
go func() {
defer wg.Done()
errC <- origin.StartTunnelDaemon(tunnelConfig, graceShutdownC, connectedSignal)
}()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"))
}
func waitToShutdown(wg *sync.WaitGroup,
errC chan error,
shutdownC, graceShutdownC chan struct{},
gracePeriod time.Duration,
) error {
var err error
if gracePeriod > 0 {
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceShutdownC, gracePeriod)
} else {
err = waitForSignal(errC, shutdownC)
close(graceShutdownC)
}
if err != nil {
logger.WithError(err).Error("Quitting due to error")
} else {
logger.Info("Quitting...")
}
// Wait for clean exit, discarding all errors
go func() {
for range errC {
}
}()
wg.Wait()
return err
}
func notifySystemd(waitForSignal chan struct{}) {
<-waitForSignal
daemon.SdNotify(false, "READY=1")
}
func writePidFile(waitForSignal chan struct{}, pidFile string) {
<-waitForSignal
file, err := os.Create(pidFile)
if err != nil {
logger.WithError(err).Errorf("Unable to write pid to %s", pidFile)
}
defer file.Close()
fmt.Fprintf(file, "%d", os.Getpid())
}
func userHomeDir() (string, error) {
// This returns the home dir of the executing user using OS-specific method
// for discovering the home dir. It's not recommended to call this function
// when the user has root permission as $HOME depends on what options the user
// use with sudo.
homeDir, err := homedir.Dir()
if err != nil {
logger.WithError(err).Error("Cannot determine home directory for the user")
return "", errors.Wrap(err, "Cannot determine home directory for the user")
}
return homeDir, nil
}

33
cmd/cloudflared/server.go Normal file
View File

@@ -0,0 +1,33 @@
package main
import (
"github.com/cloudflare/cloudflared/tunneldns"
"gopkg.in/urfave/cli.v2"
"github.com/pkg/errors"
)
func runDNSProxyServer(c *cli.Context, dnsReadySignal, shutdownC chan struct{}) error {
port := c.Int("proxy-dns-port")
if port <= 0 || port > 65535 {
logger.Errorf("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
return errors.New("The 'proxy-dns-port' must be a valid port number in <1, 65535> range.")
}
listener, err := tunneldns.CreateListener(c.String("proxy-dns-address"), uint16(port), c.StringSlice("proxy-dns-upstream"))
if err != nil {
close(dnsReadySignal)
listener.Stop()
logger.WithError(err).Error("Cannot create the DNS over HTTPS proxy server")
return errors.Wrap(err, "Cannot create the DNS over HTTPS proxy server")
}
err = listener.Start(dnsReadySignal)
if err != nil {
logger.WithError(err).Error("Cannot start the DNS over HTTPS proxy server")
return errors.Wrap(err, "Cannot start the DNS over HTTPS proxy server")
}
<-shutdownC
listener.Stop()
return nil
}

View File

@@ -0,0 +1,192 @@
package main
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"text/template"
"github.com/mitchellh/go-homedir"
)
type ServiceTemplate struct {
Path string
Content string
FileMode os.FileMode
}
type ServiceTemplateArgs struct {
Path string
}
func (st *ServiceTemplate) ResolvePath() (string, error) {
resolvedPath, err := homedir.Expand(st.Path)
if err != nil {
return "", fmt.Errorf("error resolving path %s: %v", st.Path, err)
}
return resolvedPath, nil
}
func (st *ServiceTemplate) Generate(args *ServiceTemplateArgs) error {
tmpl, err := template.New(st.Path).Parse(st.Content)
if err != nil {
return fmt.Errorf("error generating %s template: %v", st.Path, err)
}
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
var buffer bytes.Buffer
err = tmpl.Execute(&buffer, args)
if err != nil {
return fmt.Errorf("error generating %s: %v", st.Path, err)
}
fileMode := os.FileMode(0644)
if st.FileMode != 0 {
fileMode = st.FileMode
}
err = ioutil.WriteFile(resolvedPath, buffer.Bytes(), fileMode)
if err != nil {
return fmt.Errorf("error writing %s: %v", resolvedPath, err)
}
return nil
}
func (st *ServiceTemplate) Remove() error {
resolvedPath, err := st.ResolvePath()
if err != nil {
return err
}
err = os.Remove(resolvedPath)
if err != nil {
return fmt.Errorf("error deleting %s: %v", resolvedPath, err)
}
return nil
}
func runCommand(command string, args ...string) error {
cmd := exec.Command(command, args...)
stderr, err := cmd.StderrPipe()
if err != nil {
logger.WithError(err).Infof("error getting stderr pipe")
return fmt.Errorf("error getting stderr pipe: %v", err)
}
err = cmd.Start()
if err != nil {
logger.WithError(err).Infof("error starting %s", command)
return fmt.Errorf("error starting %s: %v", command, err)
}
commandErr, _ := ioutil.ReadAll(stderr)
if len(commandErr) > 0 {
logger.Errorf("%s: %s", command, commandErr)
}
err = cmd.Wait()
if err != nil {
logger.WithError(err).Infof("%s returned error", command)
return fmt.Errorf("%s returned with error: %v", command, err)
}
return nil
}
func ensureConfigDirExists(configDir string) error {
ok, err := fileExists(configDir)
if !ok && err == nil {
err = os.Mkdir(configDir, 0700)
}
return err
}
// openFile opens the file at path. If create is set and the file exists, returns nil, true, nil
func openFile(path string, create bool) (file *os.File, exists bool, err error) {
expandedPath, err := homedir.Expand(path)
if err != nil {
return nil, false, err
}
if create {
fileInfo, err := os.Stat(expandedPath)
if err == nil && fileInfo.Size() > 0 {
return nil, true, nil
}
file, err = os.OpenFile(expandedPath, os.O_RDWR|os.O_CREATE, 0600)
} else {
file, err = os.Open(expandedPath)
}
return file, false, err
}
func copyCertificate(srcConfigDir, destConfigDir, credentialFile string) error {
destCredentialPath := filepath.Join(destConfigDir, credentialFile)
destFile, exists, err := openFile(destCredentialPath, true)
if err != nil {
return err
} else if exists {
// credentials already exist, do nothing
return nil
}
defer destFile.Close()
srcCredentialPath := filepath.Join(srcConfigDir, credentialFile)
srcFile, _, err := openFile(srcCredentialPath, false)
if err != nil {
return err
}
defer srcFile.Close()
// Copy certificate
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcCredentialPath, destCredentialPath, err)
}
return nil
}
func copyCredentials(serviceConfigDir, defaultConfigDir, defaultConfigFile, defaultCredentialFile string) error {
if err := ensureConfigDirExists(serviceConfigDir); err != nil {
return err
}
if err := copyCertificate(defaultConfigDir, serviceConfigDir, defaultCredentialFile); err != nil {
return err
}
// Copy or create config
destConfigPath := filepath.Join(serviceConfigDir, defaultConfigFile)
destFile, exists, err := openFile(destConfigPath, true)
if err != nil {
logger.WithError(err).Infof("cannot open %s", destConfigPath)
return err
} else if exists {
// config already exists, do nothing
return nil
}
defer destFile.Close()
srcConfigPath := filepath.Join(defaultConfigDir, defaultConfigFile)
srcFile, _, err := openFile(srcConfigPath, false)
if err != nil {
fmt.Println("Your service needs a config file that at least specifies the hostname option.")
fmt.Println("Type in a hostname now, or leave it blank and create the config file later.")
fmt.Print("Hostname: ")
reader := bufio.NewReader(os.Stdin)
input, _ := reader.ReadString('\n')
if input == "" {
return err
}
fmt.Fprintf(destFile, "hostname: %s\n", input)
} else {
defer srcFile.Close()
_, err = io.Copy(destFile, srcFile)
if err != nil {
return fmt.Errorf("unable to copy %s to %s: %v", srcConfigPath, destConfigPath, err)
}
logger.Infof("Copied %s to %s", srcConfigPath, destConfigPath)
}
return nil
}

79
cmd/cloudflared/signal.go Normal file
View File

@@ -0,0 +1,79 @@
package main
import (
"os"
"os/signal"
"syscall"
"time"
)
// waitForSignal notifies all routines to shutdownC immediately by closing the
// shutdownC when one of the routines in main exits, or when this process receives
// SIGTERM/SIGINT
func waitForSignal(errC chan error, shutdownC chan struct{}) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(shutdownC)
return err
case <-signals:
close(shutdownC)
case <-shutdownC:
}
return nil
}
// waitForSignalWithGraceShutdown notifies all routines to shutdown immediately
// by closing the shutdownC when one of the routines in main exits.
// When this process recieves SIGTERM/SIGINT, it closes the graceShutdownC to
// notify certain routines to start graceful shutdown. When grace period is over,
// or when some routine exits, it notifies the rest of the routines to shutdown
// immediately by closing shutdownC.
// In the case of handling commands from Windows Service Manager, closing graceShutdownC
// initiate graceful shutdown.
func waitForSignalWithGraceShutdown(errC chan error,
shutdownC, graceShutdownC chan struct{},
gracePeriod time.Duration,
) error {
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
defer signal.Stop(signals)
select {
case err := <-errC:
close(graceShutdownC)
close(shutdownC)
return err
case <-signals:
close(graceShutdownC)
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-graceShutdownC:
waitForGracePeriod(signals, errC, shutdownC, gracePeriod)
case <-shutdownC:
close(graceShutdownC)
}
return nil
}
func waitForGracePeriod(signals chan os.Signal,
errC chan error,
shutdownC chan struct{},
gracePeriod time.Duration,
) {
logger.Infof("Initiating graceful shutdown...")
// Unregister signal handler early, so the client can send a second SIGTERM/SIGINT
// to force shutdown cloudflared
signal.Stop(signals)
graceTimerTick := time.Tick(gracePeriod)
// send close signal via shutdownC when grace period expires or when an
// error is encountered.
select {
case <-graceTimerTick:
case <-errC:
}
close(shutdownC)
}

View File

@@ -0,0 +1,152 @@
package main
import (
"fmt"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const tick = 100 * time.Millisecond
var (
serverErr = fmt.Errorf("server error")
shutdownErr = fmt.Errorf("receive shutdown")
graceShutdownErr = fmt.Errorf("receive grace shutdown")
)
func testChannelClosed(t *testing.T, c chan struct{}) {
select {
case <-c:
return
default:
t.Fatal("Channel should be closed")
}
}
func TestWaitForSignal(t *testing.T) {
// Test handling server error
errC := make(chan error)
shutdownC := make(chan struct{})
go func() {
errC <- serverErr
}()
// received error, shutdownC should be closed
err := waitForSignal(errC, shutdownC)
assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC)
// Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC = make(chan error)
shutdownC = make(chan struct{})
go func(shutdownC chan struct{}) {
<-shutdownC
errC <- shutdownErr
}(shutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignal
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignal(errC, shutdownC)
assert.Equal(t, nil, err)
assert.Equal(t, shutdownErr, <-errC)
testChannelClosed(t, shutdownC)
}
}
func TestWaitForSignalWithGraceShutdown(t *testing.T) {
// Test server returning error
errC := make(chan error)
shutdownC := make(chan struct{})
graceshutdownC := make(chan struct{})
go func() {
errC <- serverErr
}()
// received error, both shutdownC and graceshutdownC should be closed
err := waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, serverErr, err)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// shutdownC closed, graceshutdownC should also be closed and no error
errC = make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
close(shutdownC)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.NoError(t, err)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// graceshutdownC closed, shutdownC should also be closed and no error
errC = make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
close(graceshutdownC)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.NoError(t, err)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
// Test handling SIGTERM & SIGINT
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
<-shutdownC
errC <- graceShutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, nil, err)
assert.Equal(t, graceShutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
// Test handling SIGTERM & SIGINT, server send error before end of grace period
for _, sig := range []syscall.Signal{syscall.SIGTERM, syscall.SIGINT} {
errC := make(chan error)
shutdownC = make(chan struct{})
graceshutdownC = make(chan struct{})
go func(shutdownC, graceshutdownC chan struct{}) {
<-graceshutdownC
errC <- graceShutdownErr
<-shutdownC
errC <- shutdownErr
}(shutdownC, graceshutdownC)
go func(sig syscall.Signal) {
// sleep for a tick to prevent sending signal before calling waitForSignalWithGraceShutdown
time.Sleep(tick)
syscall.Kill(syscall.Getpid(), sig)
}(sig)
err = waitForSignalWithGraceShutdown(errC, shutdownC, graceshutdownC, tick)
assert.Equal(t, nil, err)
assert.Equal(t, shutdownErr, <-errC)
testChannelClosed(t, shutdownC)
testChannelClosed(t, graceshutdownC)
}
}

32
cmd/cloudflared/tag.go Normal file
View File

@@ -0,0 +1,32 @@
package main
import (
"fmt"
"regexp"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// Restrict key names to characters allowed in an HTTP header name.
// Restrict key values to printable characters (what is recognised as data in an HTTP header value).
var tagRegexp = regexp.MustCompile("^([a-zA-Z0-9!#$%&'*+\\-.^_`|~]+)=([[:print:]]+)$")
func NewTagFromCLI(compoundTag string) (tunnelpogs.Tag, bool) {
matches := tagRegexp.FindStringSubmatch(compoundTag)
if len(matches) == 0 {
return tunnelpogs.Tag{}, false
}
return tunnelpogs.Tag{Name: matches[1], Value: matches[2]}, true
}
func NewTagSliceFromCLI(tags []string) ([]tunnelpogs.Tag, error) {
var tagSlice []tunnelpogs.Tag
for _, compoundTag := range tags {
if tag, ok := NewTagFromCLI(compoundTag); ok {
tagSlice = append(tagSlice, tag)
} else {
return nil, fmt.Errorf("Cannot parse tag value %s", compoundTag)
}
}
return tagSlice, nil
}

View File

@@ -0,0 +1,46 @@
package main
import (
"testing"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
func TestSingleTag(t *testing.T) {
testCases := []struct {
Input string
Output tunnelpogs.Tag
Fail bool
}{
{Input: "x=y", Output: tunnelpogs.Tag{Name: "x", Value: "y"}},
{Input: "More-Complex=Tag Values", Output: tunnelpogs.Tag{Name: "More-Complex", Value: "Tag Values"}},
{Input: "First=Equals=Wins", Output: tunnelpogs.Tag{Name: "First", Value: "Equals=Wins"}},
{Input: "x=", Fail: true},
{Input: "=y", Fail: true},
{Input: "=", Fail: true},
{Input: "No spaces allowed=in key names", Fail: true},
{Input: "omg\nwtf=bbq", Fail: true},
}
for i, testCase := range testCases {
tag, ok := NewTagFromCLI(testCase.Input)
assert.Equalf(t, !testCase.Fail, ok, "mismatched success for test case %d", i)
assert.Equalf(t, testCase.Output, tag, "mismatched output for test case %d", i)
}
}
func TestTagSlice(t *testing.T) {
tagSlice, err := NewTagSliceFromCLI([]string{"a=b", "c=d", "e=f"})
assert.NoError(t, err)
assert.Len(t, tagSlice, 3)
assert.Equal(t, "a", tagSlice[0].Name)
assert.Equal(t, "b", tagSlice[0].Value)
assert.Equal(t, "c", tagSlice[1].Name)
assert.Equal(t, "d", tagSlice[1].Value)
assert.Equal(t, "e", tagSlice[2].Name)
assert.Equal(t, "f", tagSlice[2].Value)
tagSlice, err = NewTagSliceFromCLI([]string{"a=b", "=", "e=f"})
assert.Error(t, err)
}

115
cmd/cloudflared/update.go Normal file
View File

@@ -0,0 +1,115 @@
package main
import (
"os"
"runtime"
"time"
"golang.org/x/crypto/ssh/terminal"
"gopkg.in/urfave/cli.v2"
"github.com/equinox-io/equinox"
"github.com/facebookgo/grace/gracenet"
)
const (
appID = "app_idCzgxYerVD"
noUpdateInShellMessage = "cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/argo-tunnel/reference/service/"
noUpdateOnWindowsMessage = "cloudflared will not automatically update on Windows systems."
)
var publicKey = []byte(`
-----BEGIN ECDSA PUBLIC KEY-----
MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE4OWZocTVZ8Do/L6ScLdkV+9A0IYMHoOf
dsCmJ/QZ6aw0w9qkkwEpne1Lmo6+0pGexZzFZOH6w5amShn+RXt7qkSid9iWlzGq
EKx0BZogHSor9Wy5VztdFaAaVbsJiCbO
-----END ECDSA PUBLIC KEY-----
`)
type ReleaseInfo struct {
Updated bool
Version string
Error error
}
func checkForUpdates() ReleaseInfo {
var opts equinox.Options
if err := opts.SetPublicKeyPEM(publicKey); err != nil {
return ReleaseInfo{Error: err}
}
resp, err := equinox.Check(appID, opts)
switch {
case err == equinox.NotAvailableErr:
return ReleaseInfo{}
case err != nil:
return ReleaseInfo{Error: err}
}
err = resp.Apply()
if err != nil {
return ReleaseInfo{Error: err}
}
return ReleaseInfo{Updated: true, Version: resp.ReleaseVersion}
}
func update(_ *cli.Context) error {
if updateApplied() {
os.Exit(64)
}
return nil
}
func autoupdate(freq time.Duration, listeners *gracenet.Net, shutdownC chan struct{}) error {
tickC := time.Tick(freq)
for {
if updateApplied() {
os.Args = append(os.Args, "--is-autoupdated=true")
pid, err := listeners.StartProcess()
if err != nil {
logger.WithError(err).Error("Unable to restart server automatically")
return err
}
// stop old process after autoupdate. Otherwise we create a new process
// after each update
logger.Infof("PID of the new process is %d", pid)
return nil
}
select {
case <-tickC:
case <-shutdownC:
return nil
}
}
}
func updateApplied() bool {
releaseInfo := checkForUpdates()
if releaseInfo.Updated {
logger.Infof("Updated to version %s", releaseInfo.Version)
return true
}
if releaseInfo.Error != nil {
logger.WithError(releaseInfo.Error).Error("Update check failed")
}
return false
}
func isAutoupdateEnabled(c *cli.Context) bool {
if runtime.GOOS == "windows" {
logger.Info(noUpdateOnWindowsMessage)
return false
}
if isRunningFromTerminal() {
logger.Info(noUpdateInShellMessage)
return false
}
return !c.Bool("no-autoupdate") && c.Duration("autoupdate-freq") != 0
}
func isRunningFromTerminal() bool {
return terminal.IsTerminal(int(os.Stdout.Fd()))
}

View File

@@ -0,0 +1,252 @@
// +build windows
package main
// Copypasta from the example files:
// https://github.com/golang/sys/blob/master/windows/svc/example
import (
"fmt"
"os"
"time"
"unsafe"
cli "gopkg.in/urfave/cli.v2"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/eventlog"
"golang.org/x/sys/windows/svc/mgr"
)
const (
windowsServiceName = "Cloudflared"
windowsServiceDescription = "Argo Tunnel agent"
recoverActionDelay = time.Second * 20
failureCountResetPeriod = time.Hour * 24
// not defined in golang.org/x/sys/windows package
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms681988(v=vs.85).aspx
serviceConfigFailureActionsFlag = 4
)
func runApp(app *cli.App, shutdownC, graceShutdownC chan struct{}) {
app.Commands = append(app.Commands, &cli.Command{
Name: "service",
Usage: "Manages the Argo Tunnel Windows service",
Subcommands: []*cli.Command{
&cli.Command{
Name: "install",
Usage: "Install Argo Tunnel as a Windows service",
Action: installWindowsService,
},
&cli.Command{
Name: "uninstall",
Usage: "Uninstall the Argo Tunnel service",
Action: uninstallWindowsService,
},
},
})
isIntSess, err := svc.IsAnInteractiveSession()
if err != nil {
logger.Fatalf("failed to determine if we are running in an interactive session: %v", err)
}
if isIntSess {
app.Run(os.Args)
return
}
elog, err := eventlog.Open(windowsServiceName)
if err != nil {
logger.WithError(err).Errorf("Cannot open event log for %s", windowsServiceName)
return
}
defer elog.Close()
elog.Info(1, fmt.Sprintf("%s service starting", windowsServiceName))
// Run executes service name by calling windowsService which is a Handler
// interface that implements Execute method.
// It will set service status to stop after Execute returns
err = svc.Run(windowsServiceName, &windowsService{app: app, elog: elog, shutdownC: shutdownC, graceShutdownC: graceShutdownC})
if err != nil {
elog.Error(1, fmt.Sprintf("%s service failed: %v", windowsServiceName, err))
return
}
elog.Info(1, fmt.Sprintf("%s service stopped", windowsServiceName))
}
type windowsService struct {
app *cli.App
elog *eventlog.Log
shutdownC chan struct{}
graceShutdownC chan struct{}
}
// called by the package code at the start of the service
func (s *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, statusChan chan<- svc.Status) (ssec bool, errno uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown
statusChan <- svc.Status{State: svc.StartPending}
errC := make(chan error)
go func() {
errC <- s.app.Run(args)
}()
statusChan <- svc.Status{State: svc.Running, Accepts: cmdsAccepted}
for {
select {
case c := <-r:
switch c.Cmd {
case svc.Interrogate:
s.elog.Info(1, fmt.Sprintf("control request 1 #%d", c))
statusChan <- c.CurrentStatus
case svc.Stop:
s.elog.Info(1, "received stop control request")
close(s.graceShutdownC)
statusChan <- svc.Status{State: svc.StopPending}
case svc.Shutdown:
s.elog.Info(1, "received shutdown control request")
close(s.shutdownC)
statusChan <- svc.Status{State: svc.StopPending}
default:
s.elog.Error(1, fmt.Sprintf("unexpected control request #%d", c))
}
case err := <-errC:
ssec = true
if err != nil {
s.elog.Error(1, fmt.Sprintf("cloudflared terminated with error %v", err))
errno = 1
} else {
s.elog.Info(1, "cloudflared terminated without error")
errno = 0
}
return
}
}
}
func installWindowsService(c *cli.Context) error {
logger.Infof("Installing Argo Tunnel Windows service")
exepath, err := os.Executable()
if err != nil {
logger.Errorf("Cannot find path name that start the process")
return err
}
m, err := mgr.Connect()
if err != nil {
logger.WithError(err).Errorf("Cannot establish a connection to the service control manager")
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err == nil {
s.Close()
logger.Errorf("service %s already exists", windowsServiceName)
return fmt.Errorf("service %s already exists", windowsServiceName)
}
config := mgr.Config{StartType: mgr.StartAutomatic, DisplayName: windowsServiceDescription}
s, err = m.CreateService(windowsServiceName, exepath, config)
if err != nil {
logger.Errorf("Cannot install service %s", windowsServiceName)
return err
}
defer s.Close()
logger.Infof("Argo Tunnel agent service is installed")
err = eventlog.InstallAsEventCreate(windowsServiceName, eventlog.Error|eventlog.Warning|eventlog.Info)
if err != nil {
s.Delete()
logger.WithError(err).Errorf("Cannot install event logger")
return fmt.Errorf("SetupEventLogSource() failed: %s", err)
}
err = configRecoveryOption(s.Handle)
if err != nil {
logger.WithError(err).Errorf("Cannot set service recovery actions")
logger.Infof("See %s to manually configure service recovery actions", serviceUrl)
}
return nil
}
func uninstallWindowsService(c *cli.Context) error {
logger.Infof("Uninstalling Argo Tunnel Windows Service")
m, err := mgr.Connect()
if err != nil {
logger.Errorf("Cannot establish a connection to the service control manager")
return err
}
defer m.Disconnect()
s, err := m.OpenService(windowsServiceName)
if err != nil {
logger.Errorf("service %s is not installed", windowsServiceName)
return fmt.Errorf("service %s is not installed", windowsServiceName)
}
defer s.Close()
err = s.Delete()
if err != nil {
logger.Errorf("Cannot delete service %s", windowsServiceName)
return err
}
logger.Infof("Argo Tunnel agent service is uninstalled")
err = eventlog.Remove(windowsServiceName)
if err != nil {
logger.Errorf("Cannot remove event logger")
return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
}
return nil
}
// defined in https://msdn.microsoft.com/en-us/library/windows/desktop/ms685126(v=vs.85).aspx
type scAction int
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms685126(v=vs.85).aspx
const (
scActionNone scAction = iota
scActionRestart
scActionReboot
scActionRunCommand
)
// defined in https://msdn.microsoft.com/en-us/library/windows/desktop/ms685939(v=vs.85).aspx
type serviceFailureActions struct {
// time to wait to reset the failure count to zero if there are no failures in seconds
resetPeriod uint32
rebootMsg *uint16
command *uint16
// If failure count is greater than actionCount, the service controller repeats
// the last action in actions
actionCount uint32
actions uintptr
}
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms685937(v=vs.85).aspx
// Not supported in Windows Server 2003 and Windows XP
type serviceFailureActionsFlag struct {
// enableActionsForStopsWithErr is of type BOOL, which is declared as
// typedef int BOOL in C
enableActionsForStopsWithErr int
}
type recoveryAction struct {
recoveryType uint32
// The time to wait before performing the specified action, in milliseconds
delay uint32
}
// until https://github.com/golang/go/issues/23239 is release, we will need to
// configure through ChangeServiceConfig2
func configRecoveryOption(handle windows.Handle) error {
actions := []recoveryAction{
{recoveryType: uint32(scActionRestart), delay: uint32(recoverActionDelay / time.Millisecond)},
}
serviceRecoveryActions := serviceFailureActions{
resetPeriod: uint32(failureCountResetPeriod / time.Second),
actionCount: uint32(len(actions)),
actions: uintptr(unsafe.Pointer(&actions[0])),
}
if err := windows.ChangeServiceConfig2(handle, windows.SERVICE_CONFIG_FAILURE_ACTIONS, (*byte)(unsafe.Pointer(&serviceRecoveryActions))); err != nil {
return err
}
serviceFailureActionsFlag := serviceFailureActionsFlag{enableActionsForStopsWithErr: 1}
return windows.ChangeServiceConfig2(handle, serviceConfigFailureActionsFlag, (*byte)(unsafe.Pointer(&serviceFailureActionsFlag)))
}