mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:39:58 +00:00
TUN-3500: Integrate replace h2mux by http2 work with multiple origin support
This commit is contained in:
@@ -3,27 +3,32 @@ package origin
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
testTags = []tunnelpogs.Tag(nil)
|
||||
)
|
||||
|
||||
type mockHTTPRespWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
@@ -99,49 +104,39 @@ func (w *mockSSERespWriter) ReadBytes() []byte {
|
||||
return <-w.writeNotification
|
||||
}
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
func TestProxySingleOrigin(t *testing.T) {
|
||||
logger, err := logger.New()
|
||||
require.NoError(t, err)
|
||||
// let runtime pick an available port
|
||||
listener, err := hello.CreateTLSListener("127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
originURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: listener.Addr().String(),
|
||||
}
|
||||
originCA := x509.NewCertPool()
|
||||
helloCert, err := tlsconfig.GetHelloCertificateX509()
|
||||
require.NoError(t, err)
|
||||
originCA.AddCert(helloCert)
|
||||
clientTLS := &tls.Config{
|
||||
RootCAs: originCA,
|
||||
}
|
||||
proxyConfig := &ProxyConfig{
|
||||
Client: &http.Transport{
|
||||
TLSClientConfig: clientTLS,
|
||||
},
|
||||
URL: originURL,
|
||||
TLSConfig: clientTLS,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
hello.StartHelloWorldServer(logger, listener, ctx.Done())
|
||||
}()
|
||||
flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
|
||||
flagSet.Bool("hello-world", true, "")
|
||||
|
||||
client := NewClient(proxyConfig, logger)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(t, client, originURL))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, client, originURL, clientTLS))
|
||||
t.Run("testProxySSE", testProxySSE(t, client, originURL))
|
||||
cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
|
||||
err = cliCtx.Set("hello-world", "true")
|
||||
require.NoError(t, err)
|
||||
|
||||
allowURLFromArgs := false
|
||||
ingressRule, err := ingress.NewSingleOrigin(cliCtx, allowURLFromArgs, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errC := make(chan error)
|
||||
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
|
||||
|
||||
client := NewClient(ingressRule, testTags, logger)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(t, client))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, client))
|
||||
t.Run("testProxySSE", testProxySSE(t, client))
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
|
||||
func testProxyHTTP(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
respWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Proxy(respWriter, req, false)
|
||||
@@ -151,11 +146,11 @@ func testProxyHTTP(t *testing.T, client connection.OriginClient, originURL *url.
|
||||
}
|
||||
}
|
||||
|
||||
func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL *url.URL, tlsConfig *tls.Config) func(t *testing.T) {
|
||||
func testProxyWebsocket(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// WSRoute is a websocket echo handler
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s", originURL, hello.WSRoute), nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), nil)
|
||||
|
||||
readPipe, writePipe := io.Pipe()
|
||||
respWriter := newMockWSRespWriter(readPipe)
|
||||
@@ -191,7 +186,7 @@ func testProxyWebsocket(t *testing.T, client connection.OriginClient, originURL
|
||||
}
|
||||
}
|
||||
|
||||
func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.URL) func(t *testing.T) {
|
||||
func testProxySSE(t *testing.T, client connection.OriginClient) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
var (
|
||||
pushCount = 50
|
||||
@@ -199,7 +194,7 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U
|
||||
)
|
||||
respWriter := newMockSSERespWriter()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s%s?freq=%s", originURL, hello.SSERoute, pushFreq), nil)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -225,3 +220,98 @@ func testProxySSE(t *testing.T, client connection.OriginClient, originURL *url.U
|
||||
wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyMultipleOrigins(t *testing.T) {
|
||||
api := httptest.NewServer(mockAPI{})
|
||||
defer api.Close()
|
||||
|
||||
unvalidatedIngress := []config.UnvalidatedIngressRule{
|
||||
{
|
||||
Hostname: "api.example.com",
|
||||
Service: api.URL,
|
||||
},
|
||||
{
|
||||
Hostname: "hello.example.com",
|
||||
Service: "hello-world",
|
||||
},
|
||||
{
|
||||
Hostname: "health.example.com",
|
||||
Path: "/health",
|
||||
Service: "http_status:200",
|
||||
},
|
||||
{
|
||||
Hostname: "*",
|
||||
Service: "http_status:404",
|
||||
},
|
||||
}
|
||||
|
||||
ingress, err := ingress.ParseIngress(&config.Configuration{
|
||||
TunnelID: t.Name(),
|
||||
Ingress: unvalidatedIngress,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
logger, err := logger.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errC := make(chan error)
|
||||
var wg sync.WaitGroup
|
||||
ingress.StartOrigins(&wg, logger, ctx.Done(), errC)
|
||||
|
||||
client := NewClient(ingress, testTags, logger)
|
||||
|
||||
tests := []struct {
|
||||
url string
|
||||
expectedStatus int
|
||||
expectedBody []byte
|
||||
}{
|
||||
{
|
||||
url: "http://api.example.com",
|
||||
expectedStatus: http.StatusCreated,
|
||||
expectedBody: []byte("Created"),
|
||||
},
|
||||
{
|
||||
url: fmt.Sprintf("http://hello.example.com%s", hello.HealthRoute),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: []byte("ok"),
|
||||
},
|
||||
{
|
||||
url: "http://health.example.com/health",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
url: "http://health.example.com/",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
url: "http://not-found.example.com",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
respWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, test.url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Proxy(respWriter, req, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedStatus, respWriter.Code)
|
||||
if test.expectedBody != nil {
|
||||
assert.Equal(t, test.expectedBody, respWriter.Body.Bytes())
|
||||
} else {
|
||||
assert.Equal(t, 0, respWriter.Body.Len())
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
type mockAPI struct{}
|
||||
|
||||
func (ma mockAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
w.Write([]byte("Created"))
|
||||
}
|
||||
|
Reference in New Issue
Block a user