diff --git a/cmd/cloudflared/flags/flags.go b/cmd/cloudflared/flags/flags.go index 975ee401..1f8f2855 100644 --- a/cmd/cloudflared/flags/flags.go +++ b/cmd/cloudflared/flags/flags.go @@ -160,4 +160,7 @@ const ( // Virtual DNS resolver service resolver addresses to use instead of dynamically fetching them from the OS. VirtualDNSServiceResolverAddresses = "dns-resolver-addrs" + + // Management hostname to signify incoming management requests + ManagementHostname = "management-hostname" ) diff --git a/cmd/cloudflared/tail/cmd.go b/cmd/cloudflared/tail/cmd.go index 6c376033..9b2ee6a5 100644 --- a/cmd/cloudflared/tail/cmd.go +++ b/cmd/cloudflared/tail/cmd.go @@ -51,6 +51,7 @@ func buildTailManagementTokenSubcommand() *cli.Command { func managementTokenCommand(c *cli.Context) error { log := createLogger(c) + token, err := getManagementToken(c, log) if err != nil { return err @@ -99,7 +100,7 @@ func buildTailCommand(subcommands []*cli.Command) *cli.Command { EnvVars: []string{"TUNNEL_MANAGEMENT_TOKEN"}, }, &cli.StringFlag{ - Name: "management-hostname", + Name: cfdflags.ManagementHostname, Usage: "Management hostname to signify incoming management requests", EnvVars: []string{"TUNNEL_MANAGEMENT_HOSTNAME"}, Hidden: true, @@ -236,7 +237,14 @@ func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { return "", err } - client, err := userCreds.Client(c.String(cfdflags.ApiURL), buildInfo.UserAgent(), log) + var apiURL string + if userCreds.IsFEDEndpoint() { + apiURL = credentials.FedRampBaseApiURL + } else { + apiURL = c.String(cfdflags.ApiURL) + } + + client, err := userCreds.Client(apiURL, buildInfo.UserAgent(), log) if err != nil { return "", err } @@ -261,7 +269,7 @@ func getManagementToken(c *cli.Context, log *zerolog.Logger) (string, error) { // buildURL will build the management url to contain the required query parameters to authenticate the request. func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) { var err error - managementHostname := c.String("management-hostname") + token := c.String("token") if token == "" { token, err = getManagementToken(c, log) @@ -269,6 +277,19 @@ func buildURL(c *cli.Context, log *zerolog.Logger) (url.URL, error) { return url.URL{}, fmt.Errorf("unable to acquire management token for requested tunnel id: %w", err) } } + + claims, err := management.ParseToken(token) + if err != nil { + return url.URL{}, fmt.Errorf("failed to determine if token is FED: %w", err) + } + + var managementHostname string + if claims.IsFed() { + managementHostname = credentials.FedRampHostname + } else { + managementHostname = c.String(cfdflags.ManagementHostname) + } + query := url.Values{} query.Add("access_token", token) connector := c.String("connector-id") diff --git a/cmd/cloudflared/tunnel/cmd.go b/cmd/cloudflared/tunnel/cmd.go index 89b5448d..925333a4 100644 --- a/cmd/cloudflared/tunnel/cmd.go +++ b/cmd/cloudflared/tunnel/cmd.go @@ -97,7 +97,7 @@ var ( "no-tls-verify", "no-chunked-encoding", "http2-origin", - "management-hostname", + cfdflags.ManagementHostname, "service-op-ip", "local-ssh-port", "ssh-idle-timeout", @@ -459,8 +459,23 @@ func StartServer( } } + userCreds, err := credentials.Read(c.String(cfdflags.OriginCert), log) + var isFEDEndpoint bool + if err != nil { + isFEDEndpoint = false + } else { + isFEDEndpoint = userCreds.IsFEDEndpoint() + } + + var managementHostname string + if isFEDEndpoint { + managementHostname = credentials.FedRampHostname + } else { + managementHostname = c.String(cfdflags.ManagementHostname) + } + mgmt := management.New( - c.String("management-hostname"), + managementHostname, c.Bool("management-diagnostics"), serviceIP, connectorID, @@ -1042,7 +1057,7 @@ func configureProxyFlags(shouldHide bool) []cli.Flag { Value: false, }), altsrc.NewStringFlag(&cli.StringFlag{ - Name: "management-hostname", + Name: cfdflags.ManagementHostname, Usage: "Management hostname to signify incoming management requests", EnvVars: []string{"TUNNEL_MANAGEMENT_HOSTNAME"}, Hidden: true, diff --git a/credentials/credentials.go b/credentials/credentials.go index f5679b25..7abd9ae4 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -10,6 +10,8 @@ import ( const ( logFieldOriginCertPath = "originCertPath" FedEndpoint = "fed" + FedRampBaseApiURL = "https://api.fed.cloudflare.com/client/v4" + FedRampHostname = "management.fed.argotunnel.com" ) type User struct { diff --git a/management/middleware.go b/management/middleware.go index 68b4a26e..e8b6e698 100644 --- a/management/middleware.go +++ b/management/middleware.go @@ -12,14 +12,7 @@ const ( accessClaimsCtxKey ctxKey = iota ) -const ( - connectorIDQuery = "connector_id" - accessTokenQuery = "access_token" -) - -var ( - errMissingAccessToken = managementError{Code: 1001, Message: "missing access_token query parameter"} -) +var errMissingAccessToken = managementError{Code: 1001, Message: "missing access_token query parameter"} // HTTP middleware setting the parsed access_token claims in the request context func ValidateAccessTokenQueryMiddleware(next http.Handler) http.Handler { @@ -30,7 +23,7 @@ func ValidateAccessTokenQueryMiddleware(next http.Handler) http.Handler { writeHTTPErrorResponse(w, errMissingAccessToken) return } - token, err := parseToken(accessToken) + token, err := ParseToken(accessToken) if err != nil { writeHTTPErrorResponse(w, errMissingAccessToken) return diff --git a/management/token.go b/management/token.go index 16b51bcc..cfca335b 100644 --- a/management/token.go +++ b/management/token.go @@ -7,9 +7,12 @@ import ( "github.com/go-jose/go-jose/v4/jwt" ) +const tunnelstoreFEDIssuer = "fed-tunnelstore" + type managementTokenClaims struct { Tunnel tunnel `json:"tun"` Actor actor `json:"actor"` + jwt.Claims } // VerifyTunnel compares the tun claim isn't empty @@ -37,7 +40,7 @@ func (t *actor) verify() bool { return t.ID != "" } -func parseToken(token string) (*managementTokenClaims, error) { +func ParseToken(token string) (*managementTokenClaims, error) { jwt, err := jwt.ParseSigned(token, []jose.SignatureAlgorithm{jose.ES256}) if err != nil { return nil, fmt.Errorf("malformed jwt: %v", err) @@ -54,3 +57,7 @@ func parseToken(token string) (*managementTokenClaims, error) { } return &claims, nil } + +func (m *managementTokenClaims) IsFed() bool { + return m.Issuer == tunnelstoreFEDIssuer +} diff --git a/management/token_test.go b/management/token_test.go index 54982bbe..eadf278b 100644 --- a/management/token_test.go +++ b/management/token_test.go @@ -12,7 +12,7 @@ import ( ) const ( - validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6IjEifQ.eyJ0dW4iOnsiaWQiOiI3YjA5ODE0OS01MWZlLTRlZTUtYTY4Ny0zZTM3NDQ2NmVmYzciLCJhY2NvdW50X3RhZyI6ImNkMzkxZTljMDYyNmE4Zjc2Y2IxZjY3MGY2NTkxYjA1In0sImFjdG9yIjp7ImlkIjoiZGNhcnJAY2xvdWRmbGFyZS5jb20iLCJzdXBwb3J0IjpmYWxzZX0sInJlcyI6WyJsb2dzIl0sImV4cCI6MTY3NzExNzY5NiwiaWF0IjoxNjc3MTE0MDk2LCJpc3MiOiJ0dW5uZWxzdG9yZSJ9.mKenOdOy3Xi4O-grldFnAAemdlE9WajEpTDC_FwezXQTstWiRTLwU65P5jt4vNsIiZA4OJRq7bH-QYID9wf9NA" + validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6IjEifQ.eyJ0dW4iOnsiaWQiOiI3YjA5ODE0OS01MWZlLTRlZTUtYTY4Ny0zZTM3NDQ2NmVmYzciLCJhY2NvdW50X3RhZyI6ImNkMzkxZTljMDYyNmE4Zjc2Y2IxZjY3MGY2NTkxYjA1In0sImFjdG9yIjp7ImlkIjoiZGNhcnJAY2xvdWRmbGFyZS5jb20iLCJzdXBwb3J0IjpmYWxzZX0sInJlcyI6WyJsb2dzIl0sImV4cCI6MTY3NzExNzY5NiwiaWF0IjoxNjc3MTE0MDk2LCJpc3MiOiJ0dW5uZWxzdG9yZSJ9.mKenOdOy3Xi4O-grldFnAAemdlE9WajEpTDC_FwezXQTstWiRTLwU65P5jt4vNsIiZA4OJRq7bH-QYID9wf9NA" // nolint: gosec accountTag = "cd391e9c0626a8f76cb1f670f6591b05" tunnelID = "7b098149-51fe-4ee5-a687-3e374466efc7" @@ -105,12 +105,12 @@ func TestParseToken(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { jwt := signToken(t, test.claims, key) - claims, err := parseToken(jwt) + claims, err := ParseToken(jwt) if test.err != nil { require.EqualError(t, err, test.err.Error()) return } - require.Nil(t, err) + require.NoError(t, err) require.Equal(t, test.claims, *claims) }) }