mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:29:57 +00:00
TUN-7628: Correct Host parsing for Access
Will no longer provide full hostname with path from provided `--hostname` flag for cloudflared access to the Host header field. This addresses certain issues caught from a security fix in go 1.19.11 and 1.20.6 in the net/http URL parsing.
This commit is contained in:
@@ -168,68 +168,6 @@ func validateIP(scheme, host, port string) (string, error) {
|
||||
return fmt.Sprintf("%s://%s", scheme, host), nil
|
||||
}
|
||||
|
||||
// originURL shouldn't be a pointer, because this function might change the scheme
|
||||
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
|
||||
parsedURL, err := url.Parse(originURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Transport: transport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: validationTimeout,
|
||||
}
|
||||
|
||||
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
initialRequest.Host = hostname
|
||||
resp, initialErr := client.Do(initialRequest)
|
||||
if initialErr == nil {
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
|
||||
oldScheme := parsedURL.Scheme
|
||||
parsedURL.Scheme = toggleProtocol(oldScheme)
|
||||
|
||||
secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
secondRequest.Host = hostname
|
||||
resp, secondErr := client.Do(secondRequest)
|
||||
if secondErr == nil { // Worked this time--advise the user to switch protocols
|
||||
_ = resp.Body.Close()
|
||||
return errors.Errorf(
|
||||
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
|
||||
parsedURL.Host,
|
||||
oldScheme,
|
||||
parsedURL.Scheme,
|
||||
initialErr,
|
||||
originURL,
|
||||
)
|
||||
}
|
||||
|
||||
return initialErr
|
||||
}
|
||||
|
||||
func toggleProtocol(httpProtocol string) string {
|
||||
switch httpProtocol {
|
||||
case "http":
|
||||
return "https"
|
||||
case "https":
|
||||
return "http"
|
||||
default:
|
||||
return httpProtocol
|
||||
}
|
||||
}
|
||||
|
||||
// Access checks if a JWT from Cloudflare Access is valid.
|
||||
type Access struct {
|
||||
verifier *oidc.IDTokenVerifier
|
||||
|
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"context"
|
||||
"crypto/tls"
|
||||
@@ -114,179 +113,6 @@ func TestValidateUrl(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestToggleProtocol(t *testing.T) {
|
||||
assert.Equal(t, "https", toggleProtocol("http"))
|
||||
assert.Equal(t, "http", toggleProtocol("https"))
|
||||
assert.Equal(t, "random", toggleProtocol("random"))
|
||||
assert.Equal(t, "", toggleProtocol(""))
|
||||
}
|
||||
|
||||
// Happy path 1: originURL is HTTP, and HTTP connections work
|
||||
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
|
||||
originURL := "http://127.0.0.1/"
|
||||
hostname := "example.com"
|
||||
|
||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return emptyResponse(200), nil
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
t.Fatal("http works, shouldn't have tried with https")
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
|
||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return emptyResponse(503), nil
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
t.Fatal("http works, shouldn't have tried with https")
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
}
|
||||
|
||||
// Happy path 2: originURL is HTTPS, and HTTPS connections work
|
||||
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
|
||||
originURL := "https://127.0.0.1:1234/"
|
||||
hostname := "example.com"
|
||||
|
||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
t.Fatal("https works, shouldn't have tried with http")
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return emptyResponse(200), nil
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
|
||||
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
t.Fatal("https works, shouldn't have tried with http")
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return emptyResponse(503), nil
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
}
|
||||
|
||||
// Error path 1: originURL is HTTPS, but HTTP connections work
|
||||
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
|
||||
originURL := "https://127.0.0.1:1234/"
|
||||
hostname := "example.com"
|
||||
|
||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return emptyResponse(200), nil
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
|
||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return emptyResponse(503), nil
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
}
|
||||
|
||||
// Error path 2: originURL is HTTP, but HTTPS connections work
|
||||
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
|
||||
originURL := "http://127.0.0.1:1234/"
|
||||
hostname := "example.com"
|
||||
|
||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return emptyResponse(200), nil
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
|
||||
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
|
||||
assert.Equal(t, req.Host, hostname)
|
||||
if req.URL.Scheme == "http" {
|
||||
return nil, assert.AnError
|
||||
}
|
||||
if req.URL.Scheme == "https" {
|
||||
return emptyResponse(503), nil
|
||||
}
|
||||
panic("Shouldn't reach here")
|
||||
})))
|
||||
}
|
||||
|
||||
// Ensure the client does not follow 302 responses
|
||||
func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
|
||||
hostname := "example.com"
|
||||
redirectServer, redirectClient, err := createSecureMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/followedRedirect" {
|
||||
t.Fatal("shouldn't have followed the 302")
|
||||
}
|
||||
if r.Method == "CONNECT" {
|
||||
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||
} else {
|
||||
assert.Equal(t, hostname, r.Host)
|
||||
}
|
||||
w.Header().Set("Location", "/followedRedirect")
|
||||
w.WriteHeader(302)
|
||||
}))
|
||||
assert.NoError(t, err)
|
||||
defer redirectServer.Close()
|
||||
assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
|
||||
}
|
||||
|
||||
// Ensure validation times out when origin URL is nonresponsive
|
||||
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
|
||||
originURL := "http://127.0.0.1/"
|
||||
hostname := "example.com"
|
||||
oldValidationTimeout := validationTimeout
|
||||
defer func() {
|
||||
validationTimeout = oldValidationTimeout
|
||||
}()
|
||||
validationTimeout = 500 * time.Millisecond
|
||||
|
||||
// Use createMockServerAndClient, not createSecureMockServerAndClient.
|
||||
// The latter will bail with HTTP 400 immediately on an http:// request,
|
||||
// which defeats the purpose of a 'nonresponsive origin' test.
|
||||
server, client, err := createMockServerAndClient(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "CONNECT" {
|
||||
assert.Equal(t, "127.0.0.1:443", r.Host)
|
||||
} else {
|
||||
assert.Equal(t, hostname, r.Host)
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
w.WriteHeader(200)
|
||||
}))
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
err = ValidateHTTPService(originURL, hostname, client.Transport)
|
||||
fmt.Println(err)
|
||||
if err, ok := err.(net.Error); assert.True(t, ok) {
|
||||
assert.True(t, err.Timeout())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAccessValidatorOk(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
url := "test.cloudflareaccess.com"
|
||||
|
Reference in New Issue
Block a user