TUN-3007: Implement named tunnel connection registration and unregistration.

Removed flag for using quick reconnect, this logic is now always enabled.
This commit is contained in:
Igor Postelnik
2020-06-25 13:25:39 -05:00
parent 932e383051
commit 2a3d486126
9 changed files with 248 additions and 141 deletions

View File

@@ -106,7 +106,7 @@ func ssh(c *cli.Context) error {
wsConn := carrier.NewWSConnection(logger, false)
if c.NArg() > 0 || c.IsSet(sshURLFlag) {
localForwarder, err := config.ValidateUrl(c)
localForwarder, err := config.ValidateUrl(c, true)
if err != nil {
logger.Errorf("Error validating origin URL: %s", err)
return errors.Wrap(err, "error validating origin URL")

View File

@@ -6,11 +6,12 @@ import (
"path/filepath"
"runtime"
"github.com/cloudflare/cloudflared/validation"
homedir "github.com/mitchellh/go-homedir"
"gopkg.in/urfave/cli.v2"
"gopkg.in/urfave/cli.v2/altsrc"
"gopkg.in/yaml.v2"
"github.com/cloudflare/cloudflared/validation"
)
var (
@@ -176,9 +177,9 @@ func ValidateUnixSocket(c *cli.Context) (string, error) {
// ValidateUrl will validate url flag correctness. It can be either from --url or argument
// Notice ValidateUnixSocket, it will enforce --unix-socket is not used with --url or argument
func ValidateUrl(c *cli.Context) (string, error) {
func ValidateUrl(c *cli.Context, allowFromArgs bool) (string, error) {
var url = c.String("url")
if c.NArg() > 0 {
if allowFromArgs && c.NArg() > 0 {
if c.IsSet("url") {
return "", errors.New("Specified origin urls using both --url and argument. Decide which one you want, I can only support one.")
}

View File

@@ -359,7 +359,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
defer wg.Done()
hello.StartHelloWorldServer(logger, helloListener, shutdownC)
}()
c.Set("url", "https://"+helloListener.Addr().String())
forceSetFlag(c, "url", "https://"+helloListener.Addr().String())
}
if c.IsSet(sshServerFlag) {
@@ -409,7 +409,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
close(shutdownC)
}
}()
c.Set("url", "ssh://"+localServerAddress)
forceSetFlag(c, "url", "ssh://"+localServerAddress)
}
url := c.String("url")
@@ -453,7 +453,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
}
errC <- websocket.StartProxyServer(logger, listener, staticHost, shutdownC, streamHandler)
}()
c.Set("url", "http://"+listener.Addr().String())
forceSetFlag(c, "url", "http://"+listener.Addr().String())
}
transportLogger, err := createLogger(c, true)
@@ -461,7 +461,7 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
return errors.Wrap(err, "error setting up transport logger")
}
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger)
tunnelConfig, err := prepareTunnelConfig(c, buildInfo, version, logger, transportLogger, namedTunnel)
if err != nil {
return err
}
@@ -475,12 +475,21 @@ func StartServer(c *cli.Context, version string, shutdownC, graceShutdownC chan
wg.Add(1)
go func() {
defer wg.Done()
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh, namedTunnel)
errC <- origin.StartTunnelDaemon(ctx, tunnelConfig, connectedSignal, cloudflaredID, reconnectCh)
}()
return waitToShutdown(&wg, errC, shutdownC, graceShutdownC, c.Duration("grace-period"), logger)
}
// forceSetFlag attempts to set the given flag value in the closest context that has it defined
func forceSetFlag(c *cli.Context, name, value string) {
for _, ctx := range c.Lineage() {
if err := ctx.Set(name, value); err == nil {
break
}
}
}
func Before(c *cli.Context) error {
logger, err := createLogger(c, false)
if err != nil {
@@ -969,13 +978,6 @@ func tunnelFlags(shouldHide bool) []cli.Flag {
EnvVars: []string{"TUNNEL_USE_RECONNECT_TOKEN"},
Hidden: true,
}),
altsrc.NewBoolFlag(&cli.BoolFlag{
Name: "use-quick-reconnects",
Usage: "Test reestablishing connections with the new 'connection digest' flow.",
Value: true,
EnvVars: []string{"TUNNEL_USE_QUICK_RECONNECTS"},
Hidden: true,
}),
altsrc.NewDurationFlag(&cli.DurationFlag{
Name: "dial-edge-timeout",
Usage: "Maximum wait time to set up a connection with the edge",

View File

@@ -158,7 +158,10 @@ func prepareTunnelConfig(
version string,
logger logger.Service,
transportLogger logger.Service,
namedTunnel *origin.NamedTunnelConfig,
) (*origin.TunnelConfig, error) {
compatibilityMode := namedTunnel == nil
hostname, err := validation.ValidateHostname(c.String("hostname"))
if err != nil {
logger.Errorf("Invalid hostname: %s", err)
@@ -181,7 +184,7 @@ func prepareTunnelConfig(
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID})
originURL, err := config.ValidateUrl(c)
originURL, err := config.ValidateUrl(c, compatibilityMode)
if err != nil {
logger.Errorf("Error validating origin URL: %s", err)
return nil, errors.Wrap(err, "Error validating origin URL")
@@ -254,38 +257,52 @@ func prepareTunnelConfig(
return nil, errors.Wrap(err, "unable to create TLS config to connect with edge")
}
if namedTunnel != nil {
clientUUID, err := uuid.NewRandom()
if err != nil {
return nil, errors.Wrap(err, "can't generate clientUUID")
}
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientUUID[:],
Features: []string{origin.FeatureSerializedHeaders},
Version: version,
Arch: fmt.Sprintf("%s_%s", buildInfo.GoOS, buildInfo.GoArch),
}
}
return &origin.TunnelConfig{
BuildInfo: buildInfo,
ClientID: clientID,
ClientTlsConfig: httpTransport.TLSClientConfig,
CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname,
HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"),
Logger: logger,
TransportLogger: transportLogger,
MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
OriginCert: originCert,
OriginUrl: originURL,
ReportedVersion: version,
Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(),
Tags: tags,
TlsConfig: toEdgeTLSConfig,
UseDeclarativeTunnel: c.Bool("use-declarative-tunnels"),
UseReconnectToken: c.Bool("use-reconnect-token"),
UseQuickReconnects: c.Bool("use-quick-reconnects"),
BuildInfo: buildInfo,
ClientID: clientID,
ClientTlsConfig: httpTransport.TLSClientConfig,
CompressionQuality: c.Uint64("compression-quality"),
EdgeAddrs: c.StringSlice("edge"),
GracePeriod: c.Duration("grace-period"),
HAConnections: c.Int("ha-connections"),
HTTPTransport: httpTransport,
HeartbeatInterval: c.Duration("heartbeat-interval"),
Hostname: hostname,
HTTPHostHeader: c.String("http-host-header"),
IncidentLookup: origin.NewIncidentLookup(),
IsAutoupdated: c.Bool("is-autoupdated"),
IsFreeTunnel: isFreeTunnel,
LBPool: c.String("lb-pool"),
Logger: logger,
TransportLogger: transportLogger,
MaxHeartbeats: c.Uint64("heartbeat-count"),
Metrics: tunnelMetrics,
MetricsUpdateFreq: c.Duration("metrics-update-freq"),
NoChunkedEncoding: c.Bool("no-chunked-encoding"),
OriginCert: originCert,
OriginUrl: originURL,
ReportedVersion: version,
Retries: c.Uint("retries"),
RunFromTerminal: isRunningFromTerminal(),
Tags: tags,
TlsConfig: toEdgeTLSConfig,
NamedTunnel: namedTunnel,
ReplaceExisting: c.Bool("force"),
// turn off use of reconnect token and auth refresh when using named tunnels
UseReconnectToken: compatibilityMode && c.Bool("use-reconnect-token"),
}, nil
}

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"gopkg.in/urfave/cli.v2"
"gopkg.in/yaml.v2"
@@ -34,7 +35,7 @@ var (
Aliases: []string{"o"},
Usage: "Render output using given `FORMAT`. Valid options are 'json' or 'yaml'",
}
forceFlag = &cli.StringFlag{
forceFlag = &cli.BoolFlag{
Name: "force",
Aliases: []string{"f"},
Usage: "By default, if a tunnel is currently being run from a cloudflared, you can't " +
@@ -148,9 +149,12 @@ func readTunnelCredentials(tunnelID, originCertPath string) (*pogs.TunnelAuth, e
if err != nil {
return nil, errors.Wrapf(err, "couldn't read tunnel credentials from %v", filePath)
}
auth := pogs.TunnelAuth{}
err = json.Unmarshal(body, &auth)
return &auth, errors.Wrap(err, "couldn't parse tunnel credentials from JSON")
var auth pogs.TunnelAuth
if err = json.Unmarshal(body, &auth); err != nil {
return nil, err
}
return &auth, nil
}
func buildListCommand() *cli.Command {
@@ -325,6 +329,10 @@ func runTunnel(c *cli.Context) error {
return cliutil.UsageError(`"cloudflared tunnel run" requires exactly 1 argument, the ID of the tunnel to run.`)
}
id := c.Args().First()
tunnelID, err := uuid.Parse(id)
if err != nil {
return errors.Wrap(err, "error parsing tunnel ID")
}
logger, err := logger.New()
if err != nil {
@@ -340,5 +348,5 @@ func runTunnel(c *cli.Context) error {
return err
}
logger.Debugf("Read credentials for %v", credentials.AccountTag)
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: id})
return StartServer(c, version, shutdownC, graceShutdownC, &origin.NamedTunnelConfig{Auth: *credentials, ID: tunnelID})
}