AUTH-2055: Verifies token at edge on access login

This commit is contained in:
Michael Borkenstein
2019-09-19 13:47:08 -05:00
parent a412f629c2
commit 1d5cc45ac7
3 changed files with 79 additions and 15 deletions

View File

@@ -114,7 +114,7 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
wsConn, resp, err := cloudflaredWebsocket.ClientConnect(req, nil)
defer closeRespBody(resp)
if err != nil && isAccessResponse(resp) {
if err != nil && IsAccessResponse(resp) {
wsConn, err = createAccessAuthenticatedStream(options)
if err != nil {
return nil, err
@@ -126,10 +126,10 @@ func createWebsocketStream(options *StartOptions) (*cloudflaredWebsocket.Conn, e
return &cloudflaredWebsocket.Conn{Conn: wsConn}, nil
}
// isAccessResponse checks the http Response to see if the url location
// IsAccessResponse checks the http Response to see if the url location
// contains the Access structure.
func isAccessResponse(resp *http.Response) bool {
if resp == nil || resp.StatusCode <= 300 {
func IsAccessResponse(resp *http.Response) bool {
if resp == nil || resp.StatusCode != http.StatusFound {
return false
}
@@ -156,7 +156,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
return wsConn, nil
}
if !isAccessResponse(resp) {
if !IsAccessResponse(resp) {
return nil, err
}
@@ -179,7 +179,7 @@ func createAccessAuthenticatedStream(options *StartOptions) (*websocket.Conn, er
// createAccessWebSocketStream builds an Access request and makes a connection
func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.Response, error) {
req, err := buildAccessRequest(options)
req, err := BuildAccessRequest(options)
if err != nil {
return nil, nil, err
}
@@ -187,7 +187,7 @@ func createAccessWebSocketStream(options *StartOptions) (*websocket.Conn, *http.
}
// buildAccessRequest builds an HTTP request with the Access token set
func buildAccessRequest(options *StartOptions) (*http.Request, error) {
func BuildAccessRequest(options *StartOptions) (*http.Request, error) {
req, err := http.NewRequest(http.MethodGet, options.OriginURL, nil)
if err != nil {
return nil, err

View File

@@ -102,14 +102,14 @@ func TestIsAccessResponse(t *testing.T) {
ExpectedOut bool
}{
{"nil response", nil, false},
{"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false},
{"redirect with no location", &http.Response{StatusCode: http.StatusFound}, false},
{"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
{"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true},
{"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false},
{"redirect with location", &http.Response{StatusCode: http.StatusFound, Header: validLocationHeader}, true},
{"redirect with invalid location", &http.Response{StatusCode: http.StatusFound, Header: invalidLocationHeader}, false},
}
for i, tc := range testCases {
if isAccessResponse(tc.In) != tc.ExpectedOut {
if IsAccessResponse(tc.In) != tc.ExpectedOut {
t.Fatalf("Failed case %d -- %s", i, tc.Description)
}
}