mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 18:39:58 +00:00
TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService
This commit is contained in:
@@ -2,6 +2,9 @@ package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -10,12 +13,168 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBridgeServiceDestination(t *testing.T) {
|
||||
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
|
||||
// the expected response
|
||||
func assertEstablishConnectionResponse(t *testing.T,
|
||||
originProxy StreamBasedOriginProxy,
|
||||
req *http.Request,
|
||||
expectHeader http.Header,
|
||||
) {
|
||||
_, resp, err := originProxy.EstablishConnection(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, switchingProtocolText, resp.Status)
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
assert.Equal(t, expectHeader, resp.Header)
|
||||
}
|
||||
|
||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||
origin := echoWSOrigin(t)
|
||||
defer origin.Close()
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpService := &httpService{
|
||||
url: originURL,
|
||||
hostHeader: origin.URL,
|
||||
transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Test-Cloudflared-Echo", t.Name())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
"Test-Cloudflared-Echo": {t.Name()},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
|
||||
}
|
||||
|
||||
func TestHelloWorldEstablishConnection(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
helloWorldSerivce := &helloWorld{}
|
||||
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
|
||||
|
||||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
// Accept key when Sec-Websocket-Key is not specified
|
||||
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||
|
||||
close(shutdownC)
|
||||
}
|
||||
|
||||
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listenerClosed := make(chan struct{})
|
||||
tcpListenRoutine(originListener, listenerClosed)
|
||||
|
||||
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
|
||||
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := rawTCPService.EstablishConnection(req)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listenerClosed := make(chan struct{})
|
||||
tcpListenRoutine(originListener, listenerClosed)
|
||||
|
||||
originURL := &url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: originListener.Addr().String(),
|
||||
}
|
||||
|
||||
baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
|
||||
require.NoError(t, err)
|
||||
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
bastionReq := baseReq.Clone(context.Background())
|
||||
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
service *tcpOverWSService
|
||||
req *http.Request
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
service: newTCPOverWSService(originURL),
|
||||
req: baseReq,
|
||||
},
|
||||
{
|
||||
service: newBastionService(),
|
||||
req: bastionReq,
|
||||
},
|
||||
{
|
||||
service: newBastionService(),
|
||||
req: baseReq,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if test.expectErr {
|
||||
_, resp, err := test.service.EstablishConnection(test.req)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||
}
|
||||
}
|
||||
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := service.EstablishConnection(bastionReq)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBastionDestination(t *testing.T) {
|
||||
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
s := newBridgeService(nil, ServiceBastion)
|
||||
s := newBastionService()
|
||||
for _, test := range tests {
|
||||
r := &http.Request{
|
||||
Header: test.header,
|
||||
}
|
||||
dest, err := s.destination(r)
|
||||
dest, err := s.bastionDest(r)
|
||||
if test.wantErr {
|
||||
assert.Error(t, err, "Test %s expects error", test.name)
|
||||
} else {
|
||||
@@ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
url: originURL,
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
|
||||
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
@@ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
|
||||
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
close(closeChan)
|
||||
return
|
||||
}
|
||||
// Close immediately, this test is not about testing read/write on connection
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
Reference in New Issue
Block a user