TUN-3462: Refactor cloudflared to separate origin from connection

This commit is contained in:
cthuang
2020-10-08 11:12:26 +01:00
parent a5a5b93b64
commit 9ac40dcf04
32 changed files with 2006 additions and 1339 deletions

View File

@@ -66,7 +66,15 @@ func ValidateHostname(hostname string) (string, error) {
// but when it does not, the path is preserved:
// ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
// This is arguably a bug, but changing it might break some cloudflared users.
func ValidateUrl(originUrl string) (string, error) {
func ValidateUrl(originUrl string) (*url.URL, error) {
urlStr, err := validateUrlString(originUrl)
if err != nil {
return nil, err
}
return url.Parse(urlStr)
}
func validateUrlString(originUrl string) (string, error) {
if originUrl == "" {
return "", fmt.Errorf("URL should not be empty")
}
@@ -157,12 +165,8 @@ func validateIP(scheme, host, port string) (string, error) {
return fmt.Sprintf("%s://%s", scheme, host), nil
}
func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
parsedURL, err := url.Parse(originURL)
if err != nil {
return err
}
// originURL shouldn't be a pointer, because this function might change the scheme
func ValidateHTTPService(originURL url.URL, hostname string, transport http.RoundTripper) error {
client := &http.Client{
Transport: transport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
@@ -171,7 +175,7 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
Timeout: validationTimeout,
}
initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
initialRequest, err := http.NewRequest("GET", originURL.String(), nil)
if err != nil {
return err
}
@@ -183,10 +187,10 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
}
// Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
oldScheme := parsedURL.Scheme
parsedURL.Scheme = toggleProtocol(parsedURL.Scheme)
oldScheme := originURL.Scheme
originURL.Scheme = toggleProtocol(originURL.Scheme)
secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
secondRequest, err := http.NewRequest("GET", originURL.String(), nil)
if err != nil {
return err
}
@@ -195,12 +199,12 @@ func ValidateHTTPService(originURL string, hostname string, transport http.Round
if secondErr == nil { // Worked this time--advise the user to switch protocols
resp.Body.Close()
return errors.Errorf(
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %s",
parsedURL.Host,
"%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
originURL.Host,
oldScheme,
parsedURL.Scheme,
originURL.Scheme,
initialErr,
parsedURL,
originURL,
)
}
@@ -224,12 +228,12 @@ type Access struct {
}
func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
domainURL, err := ValidateUrl(domain)
domainURL, err := validateUrlString(domain)
if err != nil {
return nil, err
}
issuerURL, err := ValidateUrl(issuer)
issuerURL, err := validateUrlString(issuer)
if err != nil {
return nil, err
}

View File

@@ -101,7 +101,7 @@ func TestValidateUrl(t *testing.T) {
for i, testCase := range testCases {
validUrl, err := ValidateUrl(testCase.input)
assert.NoError(t, err, "test case %v", i)
assert.Equal(t, testCase.expectedOutput, validUrl, "test case %v", i)
assert.Equal(t, testCase.expectedOutput, validUrl.String(), "test case %v", i)
}
validUrl, err := ValidateUrl("")
@@ -123,7 +123,7 @@ func TestToggleProtocol(t *testing.T) {
// Happy path 1: originURL is HTTP, and HTTP connections work
func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
originURL := "http://127.0.0.1/"
originURL := mustParse(t, "http://127.0.0.1/")
hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@@ -151,7 +151,7 @@ func TestValidateHTTPService_HTTP2HTTP(t *testing.T) {
// Happy path 2: originURL is HTTPS, and HTTPS connections work
func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
originURL := "https://127.0.0.1/"
originURL := mustParse(t, "https://127.0.0.1:1234/")
hostname := "example.com"
assert.Nil(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@@ -179,7 +179,7 @@ func TestValidateHTTPService_HTTPS2HTTPS(t *testing.T) {
// Error path 1: originURL is HTTPS, but HTTP connections work
func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
originURL := "https://127.0.0.1:1234/"
originURL := mustParse(t, "https://127.0.0.1:1234/")
hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
@@ -207,10 +207,13 @@ func TestValidateHTTPService_HTTPS2HTTP(t *testing.T) {
// Error path 2: originURL is HTTP, but HTTPS connections work
func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
originURL := "http://127.0.0.1:1234/"
originURLWithPort := url.URL{
Scheme: "http",
Host: "127.0.0.1:1234",
}
hostname := "example.com"
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return nil, assert.AnError
@@ -221,7 +224,7 @@ func TestValidateHTTPService_HTTP2HTTPS(t *testing.T) {
panic("Shouldn't reach here")
})))
assert.Error(t, ValidateHTTPService(originURL, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Error(t, ValidateHTTPService(originURLWithPort, hostname, testRoundTripper(func(req *http.Request) (*http.Response, error) {
assert.Equal(t, req.Host, hostname)
if req.URL.Scheme == "http" {
return nil, assert.AnError
@@ -250,12 +253,14 @@ func TestValidateHTTPService_NoFollowRedirects(t *testing.T) {
}))
assert.NoError(t, err)
defer redirectServer.Close()
assert.NoError(t, ValidateHTTPService(redirectServer.URL, hostname, redirectClient.Transport))
redirectServerURL, err := url.Parse(redirectServer.URL)
assert.NoError(t, err)
assert.NoError(t, ValidateHTTPService(*redirectServerURL, hostname, redirectClient.Transport))
}
// Ensure validation times out when origin URL is nonresponsive
func TestValidateHTTPService_NonResponsiveOrigin(t *testing.T) {
originURL := "http://127.0.0.1/"
originURL := mustParse(t, "http://127.0.0.1/")
hostname := "example.com"
oldValidationTimeout := validationTimeout
defer func() {
@@ -371,3 +376,9 @@ func createSecureMockServerAndClient(handler http.Handler) (*httptest.Server, *h
return server, client, nil
}
func mustParse(t *testing.T, originURL string) url.URL {
parsedURL, err := url.Parse(originURL)
assert.NoError(t, err)
return *parsedURL
}