mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 16:29:58 +00:00
TUN-7628: Correct Host parsing for Access
Will no longer provide full hostname with path from provided `--hostname` flag for cloudflared access to the Host header field. This addresses certain issues caught from a security fix in go 1.19.11 and 1.20.6 in the net/http URL parsing.
This commit is contained in:
@@ -70,14 +70,14 @@ func ssh(c *cli.Context) error {
|
||||
|
||||
// get the hostname from the cmdline and error out if its not provided
|
||||
rawHostName := c.String(sshHostnameFlag)
|
||||
hostname, err := validation.ValidateHostname(rawHostName)
|
||||
if err != nil || rawHostName == "" {
|
||||
url, err := parseURL(rawHostName)
|
||||
if err != nil {
|
||||
log.Err(err).Send()
|
||||
return cli.ShowCommandHelp(c, "ssh")
|
||||
}
|
||||
originURL := ensureURLScheme(hostname)
|
||||
|
||||
// get the headers from the cmdline and add them
|
||||
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||
headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||
if c.IsSet(sshTokenIDFlag) {
|
||||
headers.Set(cfAccessClientIDHeader, c.String(sshTokenIDFlag))
|
||||
}
|
||||
@@ -89,9 +89,9 @@ func ssh(c *cli.Context) error {
|
||||
carrier.SetBastionDest(headers, c.String(sshDestinationFlag))
|
||||
|
||||
options := &carrier.StartOptions{
|
||||
OriginURL: originURL,
|
||||
OriginURL: url.String(),
|
||||
Headers: headers,
|
||||
Host: hostname,
|
||||
Host: url.Host,
|
||||
}
|
||||
|
||||
if connectTo := c.String(sshConnectTo); connectTo != "" {
|
||||
@@ -138,20 +138,9 @@ func ssh(c *cli.Context) error {
|
||||
// default to 10 if provided but unset
|
||||
maxMessages = 10
|
||||
}
|
||||
logger := log.With().Str("host", hostname).Logger()
|
||||
logger := log.With().Str("host", url.Host).Logger()
|
||||
s = stream.NewDebugStream(s, &logger, maxMessages)
|
||||
}
|
||||
carrier.StartClient(wsConn, s, options)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRequestHeaders(values []string) http.Header {
|
||||
headers := make(http.Header)
|
||||
for _, valuePair := range values {
|
||||
header, value, found := strings.Cut(valuePair, ":")
|
||||
if found {
|
||||
headers.Add(strings.TrimSpace(header), strings.TrimSpace(value))
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
@@ -1,19 +0,0 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBuildRequestHeaders(t *testing.T) {
|
||||
headers := make(http.Header)
|
||||
headers.Add("client", "value")
|
||||
headers.Add("secret", "safe-value")
|
||||
|
||||
values := buildRequestHeaders([]string{"client: value", "secret: safe-value", "trash", "cf-trace-id: 000:000:0:1:asd"})
|
||||
assert.Equal(t, headers.Get("client"), values.Get("client"))
|
||||
assert.Equal(t, headers.Get("secret"), values.Get("secret"))
|
||||
assert.Equal(t, headers.Get("cf-trace-id"), values.Get("000:000:0:1:asd"))
|
||||
}
|
@@ -222,8 +222,7 @@ func login(c *cli.Context) error {
|
||||
log := logger.CreateLoggerFromContext(c, logger.EnableTerminalLog)
|
||||
|
||||
args := c.Args()
|
||||
rawURL := ensureURLScheme(args.First())
|
||||
appURL, err := url.Parse(rawURL)
|
||||
appURL, err := parseURL(args.First())
|
||||
if args.Len() < 1 || err != nil {
|
||||
log.Error().Msg("Please provide the url of the Access application")
|
||||
return err
|
||||
@@ -252,16 +251,6 @@ func login(c *cli.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureURLScheme prepends a URL with https:// if it doesn't have a scheme. http:// URLs will not be converted.
|
||||
func ensureURLScheme(url string) string {
|
||||
url = strings.Replace(strings.ToLower(url), "http://", "https://", 1)
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
url = fmt.Sprintf("https://%s", url)
|
||||
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
// curl provides a wrapper around curl, passing Access JWT along in request
|
||||
func curl(c *cli.Context) error {
|
||||
err := sentry.Init(sentry.ClientOptions{
|
||||
@@ -345,7 +334,7 @@ func generateToken(c *cli.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
appURL, err := url.Parse(ensureURLScheme(c.String("app")))
|
||||
appURL, err := parseURL(c.String("app"))
|
||||
if err != nil || c.NumFlags() < 1 {
|
||||
fmt.Fprintln(os.Stderr, "Please provide a url.")
|
||||
return err
|
||||
@@ -398,7 +387,7 @@ func sshGen(c *cli.Context) error {
|
||||
return cli.ShowCommandHelp(c, "ssh-gen")
|
||||
}
|
||||
|
||||
originURL, err := url.Parse(ensureURLScheme(hostname))
|
||||
originURL, err := parseURL(hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -499,7 +488,7 @@ func isFileThere(candidate string) bool {
|
||||
// Then makes a request to to the origin with the token to ensure it is valid.
|
||||
// Returns nil if token is valid.
|
||||
func verifyTokenAtEdge(appUrl *url.URL, appInfo *token.AppInfo, c *cli.Context, log *zerolog.Logger) error {
|
||||
headers := buildRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||
headers := parseRequestHeaders(c.StringSlice(sshHeaderFlag))
|
||||
if c.IsSet(sshTokenIDFlag) {
|
||||
headers.Add(cfAccessClientIDHeader, c.String(sshTokenIDFlag))
|
||||
}
|
||||
|
@@ -1,25 +0,0 @@
|
||||
package access
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_ensureURLScheme(t *testing.T) {
|
||||
type args struct {
|
||||
url string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{"no scheme", args{"localhost:123"}, "https://localhost:123"},
|
||||
{"http scheme", args{"http://test"}, "https://test"},
|
||||
{"https scheme", args{"https://test"}, "https://test"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := ensureURLScheme(tt.args.url); got != tt.want {
|
||||
t.Errorf("ensureURLScheme() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
55
cmd/cloudflared/access/validation.go
Normal file
55
cmd/cloudflared/access/validation.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// parseRequestHeaders will take user-provided header values as strings "Content-Type: application/json" and create
|
||||
// a http.Header object.
|
||||
func parseRequestHeaders(values []string) http.Header {
|
||||
headers := make(http.Header)
|
||||
for _, valuePair := range values {
|
||||
header, value, found := strings.Cut(valuePair, ":")
|
||||
if found {
|
||||
headers.Add(strings.TrimSpace(header), strings.TrimSpace(value))
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// parseHostname will attempt to convert a user provided URL string into a string with some light error checking on
|
||||
// certain expectations from the URL.
|
||||
// Will convert all HTTP URLs to HTTPS
|
||||
func parseURL(input string) (*url.URL, error) {
|
||||
if input == "" {
|
||||
return nil, errors.New("no input provided")
|
||||
}
|
||||
if !strings.HasPrefix(input, "https://") && !strings.HasPrefix(input, "http://") {
|
||||
input = fmt.Sprintf("https://%s", input)
|
||||
}
|
||||
url, err := url.ParseRequestURI(input)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse as URL: %w", err)
|
||||
}
|
||||
if url.Scheme != "https" {
|
||||
url.Scheme = "https"
|
||||
}
|
||||
if url.Host == "" {
|
||||
return nil, errors.New("failed to parse Host")
|
||||
}
|
||||
host, err := httpguts.PunycodeHostPort(url.Host)
|
||||
if err != nil || host == "" {
|
||||
return nil, err
|
||||
}
|
||||
if !httpguts.ValidHostHeader(host) {
|
||||
return nil, errors.New("invalid Host provided")
|
||||
}
|
||||
url.Host = host
|
||||
return url, nil
|
||||
}
|
80
cmd/cloudflared/access/validation_test.go
Normal file
80
cmd/cloudflared/access/validation_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package access
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseRequestHeaders(t *testing.T) {
|
||||
values := parseRequestHeaders([]string{"client: value", "secret: safe-value", "trash", "cf-trace-id: 000:000:0:1:asd"})
|
||||
assert.Len(t, values, 3)
|
||||
assert.Equal(t, "value", values.Get("client"))
|
||||
assert.Equal(t, "safe-value", values.Get("secret"))
|
||||
assert.Equal(t, "000:000:0:1:asd", values.Get("cf-trace-id"))
|
||||
}
|
||||
|
||||
func TestParseURL(t *testing.T) {
|
||||
schemes := []string{
|
||||
"http://",
|
||||
"https://",
|
||||
"",
|
||||
}
|
||||
hosts := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"localhost", "localhost"},
|
||||
{"127.0.0.1", "127.0.0.1"},
|
||||
{"127.0.0.1:9090", "127.0.0.1:9090"},
|
||||
{"::1", "::1"},
|
||||
{"::1:8080", "::1:8080"},
|
||||
{"[::1]", "[::1]"},
|
||||
{"[::1]:8080", "[::1]:8080"},
|
||||
{":8080", ":8080"},
|
||||
{"example.com", "example.com"},
|
||||
{"hello.example.com", "hello.example.com"},
|
||||
{"bücher.example.com", "xn--bcher-kva.example.com"},
|
||||
}
|
||||
paths := []string{
|
||||
"",
|
||||
"/test",
|
||||
"/example.com?qwe=123",
|
||||
}
|
||||
for i, scheme := range schemes {
|
||||
for j, host := range hosts {
|
||||
for k, path := range paths {
|
||||
t.Run(fmt.Sprintf("%d_%d_%d", i, j, k), func(t *testing.T) {
|
||||
input := fmt.Sprintf("%s%s%s", scheme, host.input, path)
|
||||
expected := fmt.Sprintf("%s%s%s", "https://", host.expected, path)
|
||||
url, err := parseURL(input)
|
||||
assert.NoError(t, err, "input: %s\texpected: %s", input, expected)
|
||||
assert.Equal(t, expected, url.String())
|
||||
assert.Equal(t, host.expected, url.Host)
|
||||
assert.Equal(t, "https", url.Scheme)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("no input", func(t *testing.T) {
|
||||
_, err := parseURL("")
|
||||
assert.ErrorContains(t, err, "no input provided")
|
||||
})
|
||||
|
||||
t.Run("missing host", func(t *testing.T) {
|
||||
_, err := parseURL("https:///host")
|
||||
assert.ErrorContains(t, err, "failed to parse Host")
|
||||
})
|
||||
|
||||
t.Run("invalid path only", func(t *testing.T) {
|
||||
_, err := parseURL("/host")
|
||||
assert.ErrorContains(t, err, "failed to parse Host")
|
||||
})
|
||||
|
||||
t.Run("invalid parse URL", func(t *testing.T) {
|
||||
_, err := parseURL("https://host\\host")
|
||||
assert.ErrorContains(t, err, "failed to parse as URL")
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user