mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 14:39:58 +00:00
TUN-3489: Add unit tests to cover proxy logic in connection package of cloudflared
This commit is contained in:
303
connection/http2_test.go
Normal file
303
connection/http2_test.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
testTransport = http2.Transport{}
|
||||
)
|
||||
|
||||
func newTestHTTP2Connection() (*http2Connection, net.Conn) {
|
||||
edgeConn, originConn := net.Pipe()
|
||||
var connIndex = uint8(0)
|
||||
return NewHTTP2Connection(
|
||||
originConn,
|
||||
testConfig,
|
||||
&NamedTunnelConfig{},
|
||||
&pogs.ConnectionOptions{},
|
||||
testObserver,
|
||||
connIndex,
|
||||
mockConnectedFuse{},
|
||||
), edgeConn
|
||||
}
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
tests := []testRequest{
|
||||
{
|
||||
name: "ok",
|
||||
endpoint: "ok",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||
},
|
||||
{
|
||||
name: "large_file",
|
||||
endpoint: "large_file",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: testLargeResp,
|
||||
},
|
||||
{
|
||||
name: "Bad request",
|
||||
endpoint: "400",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
|
||||
},
|
||||
{
|
||||
name: "Internal server error",
|
||||
endpoint: "500",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
|
||||
},
|
||||
{
|
||||
name: "Proxy error",
|
||||
endpoint: "error",
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
expectedBody: nil,
|
||||
isProxyError: true,
|
||||
},
|
||||
}
|
||||
|
||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectedStatus, resp.StatusCode)
|
||||
if test.expectedBody != nil {
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expectedBody, respBody)
|
||||
}
|
||||
if test.isProxyError {
|
||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(responseMetaHeaderField))
|
||||
} else {
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type mockNamedTunnelRPCClient struct {
|
||||
registered chan struct{}
|
||||
unregistered chan struct{}
|
||||
}
|
||||
|
||||
func (mc mockNamedTunnelRPCClient) RegisterConnection(
|
||||
c context.Context,
|
||||
config *NamedTunnelConfig,
|
||||
options *tunnelpogs.ConnectionOptions,
|
||||
connIndex uint8,
|
||||
observer *Observer,
|
||||
) error {
|
||||
close(mc.registered)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) {
|
||||
close(mc.unregistered)
|
||||
}
|
||||
|
||||
func (mockNamedTunnelRPCClient) Close() {}
|
||||
|
||||
type mockRPCClientFactory struct {
|
||||
registered chan struct{}
|
||||
unregistered chan struct{}
|
||||
}
|
||||
|
||||
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient {
|
||||
return mockNamedTunnelRPCClient{
|
||||
registered: mf.registered,
|
||||
unregistered: mf.unregistered,
|
||||
}
|
||||
}
|
||||
|
||||
type wsRespWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
readPipe *io.PipeReader
|
||||
writePipe *io.PipeWriter
|
||||
}
|
||||
|
||||
func newWSRespWriter() *wsRespWriter {
|
||||
readPipe, writePipe := io.Pipe()
|
||||
return &wsRespWriter{
|
||||
httptest.NewRecorder(),
|
||||
readPipe,
|
||||
writePipe,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
||||
return nowriter{w.readPipe}
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
||||
return w.writePipe.Write(data)
|
||||
}
|
||||
|
||||
func TestServeWS(t *testing.T) {
|
||||
http2Conn, _ := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
respWriter := newWSRespWriter()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, websocketUpgrade)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.ServeHTTP(respWriter, req)
|
||||
}()
|
||||
|
||||
data := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||
|
||||
cancel()
|
||||
resp := respWriter.Result()
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(responseMetaHeaderField))
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestServeControlStream(t *testing.T) {
|
||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||
|
||||
rpcClientFactory := mockRPCClientFactory{
|
||||
registered: make(chan struct{}),
|
||||
unregistered: make(chan struct{}),
|
||||
}
|
||||
http2Conn.newRPCClientFunc = rpcClientFactory.newMockRPCClient
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
edgeHTTP2Conn.RoundTrip(req)
|
||||
}()
|
||||
|
||||
<-rpcClientFactory.registered
|
||||
cancel()
|
||||
<-rpcClientFactory.unregistered
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
||||
http2Conn, edgeConn := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
require.NoError(b, err)
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(b, err)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StartTimer()
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
b.StopTimer()
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, test.expectedStatus, resp.StatusCode)
|
||||
if test.expectedBody != nil {
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, test.expectedBody, respBody)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
func BenchmarkServeHTTPSimple(b *testing.B) {
|
||||
test := testRequest{
|
||||
name: "ok",
|
||||
endpoint: "ok",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
||||
}
|
||||
|
||||
benchmarkServeHTTP(b, test)
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTPLargeFile(b *testing.B) {
|
||||
test := testRequest{
|
||||
name: "large_file",
|
||||
endpoint: "large_file",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: testLargeResp,
|
||||
}
|
||||
|
||||
benchmarkServeHTTP(b, test)
|
||||
}
|
Reference in New Issue
Block a user