cloudflared/carrier/carrier_test.go
Austin Cherry 8f25704a90 AUTH-1736: Better handling of token revocation
We removed all token validation from cloudflared and now rely on
the edge to do the validation. This is better because the edge is
the only thing that fully knows about token revocation. So if a user
logs out, or the application revokes all its tokens, cloudflared will
now handle that process gracefully instead of failing on it.

When we go to fetch a token we will check for the existence of a
lock file. If the lock file exists, we stop and poll every half
second to see if the lock is still there. Once the lock file is
removed, it will restart the function to (hopefully) go pick up
the valid token that was just created.
2019-07-10 21:35:46 +00:00

156 lines
3.9 KiB
Go

package carrier
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
ws "github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
// testSecWebsocketKey is the sample Sec-WebSocket-Key value used in RFC 6455;
// it is the base64 encoding of the string "the sample nonce".
const testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
// testStreamer is an in-memory io.ReadWriter used as a stream in tests.
// All access to the underlying buffer is serialized by a single mutex.
type testStreamer struct {
	buf *bytes.Buffer
	l   sync.Mutex
}

// newTestStream returns a testStreamer backed by an empty buffer.
func newTestStream() *testStreamer {
	return &testStreamer{buf: new(bytes.Buffer)}
}

// Read drains up to len(p) bytes from the buffer.
//
// It takes the exclusive lock: bytes.Buffer.Read advances the buffer's read
// offset (it mutates state), so a shared read lock would let concurrent
// readers race with each other.
func (s *testStreamer) Read(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Read(p)
}

// Write appends p to the buffer under the lock.
func (s *testStreamer) Write(p []byte) (int, error) {
	s.l.Lock()
	defer s.l.Unlock()
	return s.buf.Write(p)
}
// TestStartClient runs StartClient against a local echo websocket server with
// an in-memory stream, then checks that a message written to the stream can be
// read back intact.
//
// NOTE(review): if StartClient has already returned by the time the message is
// written, the round trip below only exercises testStreamer's buffer rather
// than the websocket echo — confirm against StartClient's semantics.
func TestStartClient(t *testing.T) {
	message := "Good morning Austin! Time for another sunny day in the great state of Texas."
	logger := logrus.New()
	ts := newTestWebSocketServer()
	defer ts.Close()

	buf := newTestStream()
	options := &StartOptions{
		OriginURL: "http://" + ts.Listener.Addr().String(),
		Headers:   nil,
	}
	err := StartClient(logger, buf, options)
	assert.NoError(t, err)

	_, err = buf.Write([]byte(message))
	assert.NoError(t, err)

	// ReadFull reports a short read as an error instead of silently comparing
	// a partially filled buffer.
	readBuffer := make([]byte, len(message))
	_, err = io.ReadFull(buf, readBuffer)
	assert.NoError(t, err)
	assert.Equal(t, message, string(readBuffer))
}
// TestStartServer starts Serve on a local TCP listener, pointed at an echo
// websocket server, then dials the listener, writes a message, and expects the
// same bytes to be echoed back.
func TestStartServer(t *testing.T) {
	listener, err := net.Listen("tcp", "localhost:")
	if err != nil {
		t.Fatalf("Error starting listener: %v", err)
	}
	message := "Good morning Austin! Time for another sunny day in the great state of Texas."
	logger := logrus.New()
	shutdownC := make(chan struct{})
	ts := newTestWebSocketServer()
	defer ts.Close()
	options := &StartOptions{
		OriginURL: "http://" + ts.Listener.Addr().String(),
		Headers:   nil,
	}

	// NOTE(review): shutdownC is never closed, so this goroutine may outlive
	// the test; confirm Serve's shutdown contract before closing it here.
	go func() {
		if err := Serve(logger, listener, shutdownC, options); err != nil {
			// Fatalf/FailNow must only be called from the test goroutine;
			// Errorf is safe to call from other goroutines.
			t.Errorf("Error running server: %v", err)
		}
	}()

	conn, err := net.Dial("tcp", listener.Addr().String())
	if err != nil {
		t.Fatalf("Error dialing server: %v", err)
	}
	defer conn.Close()

	_, err = conn.Write([]byte(message))
	assert.NoError(t, err)

	// A single TCP Read may return only part of the echo; ReadFull waits for
	// the whole message (or fails).
	readBuffer := make([]byte, len(message))
	_, err = io.ReadFull(conn, readBuffer)
	assert.NoError(t, err)
	assert.Equal(t, message, string(readBuffer))
}
// TestIsAccessResponse checks isAccessResponse against a table of responses:
// nil, non-redirects, redirects without a location, and redirects whose
// location does or does not point at a cloudflareaccess.com login URL.
// Each case runs as a named subtest so every failure is reported.
func TestIsAccessResponse(t *testing.T) {
	validLocationHeader := http.Header{}
	validLocationHeader.Add("location", "https://test.cloudflareaccess.com/cdn-cgi/access/login/blahblah")
	invalidLocationHeader := http.Header{}
	invalidLocationHeader.Add("location", "https://google.com")

	testCases := []struct {
		Description string
		In          *http.Response
		ExpectedOut bool
	}{
		{"nil response", nil, false},
		{"redirect with no location", &http.Response{StatusCode: http.StatusPermanentRedirect}, false},
		{"200 ok", &http.Response{StatusCode: http.StatusOK}, false},
		{"redirect with location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: validLocationHeader}, true},
		{"redirect with invalid location", &http.Response{StatusCode: http.StatusPermanentRedirect, Header: invalidLocationHeader}, false},
	}

	for _, tc := range testCases {
		t.Run(tc.Description, func(t *testing.T) {
			if got := isAccessResponse(tc.In); got != tc.ExpectedOut {
				t.Errorf("isAccessResponse() = %v, want %v", got, tc.ExpectedOut)
			}
		})
	}
}
// newTestWebSocketServer starts an httptest server that upgrades each request
// to a websocket and echoes every message back with the same message type,
// until a read or write fails. Callers must Close the returned server.
func newTestWebSocketServer() *httptest.Server {
	upgrader := ws.Upgrader{
		ReadBufferSize:  1024,
		WriteBufferSize: 1024,
	}
	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := upgrader.Upgrade(w, r, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error;
			// conn is nil here, so there is nothing to echo or close.
			return
		}
		defer conn.Close()
		for {
			mt, message, err := conn.ReadMessage()
			if err != nil {
				return
			}
			if err := conn.WriteMessage(mt, message); err != nil {
				return
			}
		}
	}))
}
// testRequest builds a GET request to url, with stream as the body, carrying
// websocket-upgrade handshake headers: testSecWebsocketKey as the key, version
// 13, the "tunnel-protocol" subprotocol, and a curl User-Agent. It fails the
// test immediately if the request cannot be constructed.
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
	t.Helper()
	req, err := http.NewRequest("GET", url, stream)
	if err != nil {
		t.Fatalf("testRequest: building request for %q: %v", url, err)
	}
	req.Header.Add("Connection", "Upgrade")
	req.Header.Add("Upgrade", "WebSocket")
	req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
	req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
	req.Header.Add("Sec-Websocket-Version", "13")
	req.Header.Add("User-Agent", "curl/7.59.0")
	return req
}