mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:09:58 +00:00
TUN-3868: Refactor singleTCPService and bridgeService to tcpOverWSService and rawTCPService
This commit is contained in:
@@ -25,7 +25,6 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
ServiceBridge = "bridge service"
|
||||
ServiceBastion = "bastion"
|
||||
ServiceWarpRouting = "warp-routing"
|
||||
)
|
||||
@@ -98,8 +97,7 @@ type WarpRoutingService struct {
|
||||
}
|
||||
|
||||
func NewWarpRoutingService() *WarpRoutingService {
|
||||
warpRoutingService := newBridgeService(DefaultStreamHandler, ServiceWarpRouting)
|
||||
return &WarpRoutingService{Proxy: warpRoutingService}
|
||||
return &WarpRoutingService{Proxy: &rawTCPService{name: ServiceWarpRouting}}
|
||||
}
|
||||
|
||||
// Get a single origin service from the CLI/config.
|
||||
@@ -108,7 +106,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
|
||||
return new(helloWorld), nil
|
||||
}
|
||||
if c.IsSet(config.BastionFlag) {
|
||||
return newBridgeService(nil, ServiceBastion), nil
|
||||
return newBastionService(), nil
|
||||
}
|
||||
if c.IsSet("url") {
|
||||
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
||||
@@ -120,7 +118,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originServ
|
||||
url: originURL,
|
||||
}, nil
|
||||
}
|
||||
return newSingleTCPService(originURL), nil
|
||||
return newTCPOverWSService(originURL), nil
|
||||
}
|
||||
if c.IsSet("unix-socket") {
|
||||
path, err := config.ValidateUnixSocket(c)
|
||||
@@ -182,7 +180,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||
// overwrite the localService.URL field when `start` is called. So,
|
||||
// leave the URL field empty for now.
|
||||
cfg.BastionMode = true
|
||||
service = newBridgeService(nil, ServiceBastion)
|
||||
service = newBastionService()
|
||||
} else {
|
||||
// Validate URL services
|
||||
u, err := url.Parse(r.Service)
|
||||
@@ -200,7 +198,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||
if isHTTPService(u) {
|
||||
service = &httpService{url: u}
|
||||
} else {
|
||||
service = newSingleTCPService(u)
|
||||
service = newTCPOverWSService(u)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -238,12 +238,12 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "tcp.foo.com",
|
||||
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "tcp://127.0.0.1:7864")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
{
|
||||
Hostname: "tcp2.foo.com",
|
||||
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "tcp://localhost:8000")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
{
|
||||
@@ -260,7 +260,7 @@ ingress:
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "ssh://127.0.0.1:22")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -273,7 +273,7 @@ ingress:
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "rdp://127.0.0.1:3389")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -286,7 +286,7 @@ ingress:
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "smb://127.0.0.1:445")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -299,7 +299,7 @@ ingress:
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
|
||||
Service: newTCPOverWSService(MustParseURL(t, "ftp://127.0.0.1")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -316,7 +316,7 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "bastion.foo.com",
|
||||
Service: newBridgeService(nil, ServiceBastion),
|
||||
Service: newBastionService(),
|
||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||
},
|
||||
{
|
||||
@@ -336,7 +336,7 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "bastion.foo.com",
|
||||
Service: newBridgeService(nil, ServiceBastion),
|
||||
Service: newBastionService(),
|
||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||
},
|
||||
{
|
||||
|
@@ -1,11 +1,12 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
@@ -15,9 +16,8 @@ import (
|
||||
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
|
||||
type OriginConnection interface {
|
||||
// Stream should generally be implemented as a bidirectional io.Copy.
|
||||
Stream(tunnelConn io.ReadWriter, log *zerolog.Logger)
|
||||
Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger)
|
||||
Close()
|
||||
Type() connection.Type
|
||||
}
|
||||
|
||||
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
||||
@@ -54,30 +54,38 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
|
||||
|
||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||
type tcpConnection struct {
|
||||
conn net.Conn
|
||||
streamHandler streamHandlerFunc
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
tc.streamHandler(tunnelConn, tc.conn, log)
|
||||
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
Stream(tunnelConn, tc.conn, log)
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Close() {
|
||||
tc.conn.Close()
|
||||
}
|
||||
|
||||
func (*tcpConnection) Type() connection.Type {
|
||||
return connection.TypeTCP
|
||||
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
|
||||
type tcpOverWSConnection struct {
|
||||
conn net.Conn
|
||||
streamHandler streamHandlerFunc
|
||||
}
|
||||
|
||||
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
|
||||
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
|
||||
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
|
||||
}
|
||||
|
||||
func (wc *tcpOverWSConnection) Close() {
|
||||
wc.conn.Close()
|
||||
}
|
||||
|
||||
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
|
||||
type wsConnection struct {
|
||||
wsConn *gws.Conn
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||
}
|
||||
|
||||
@@ -86,13 +94,9 @@ func (wsc *wsConnection) Close() {
|
||||
wsc.wsConn.Close()
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Type() connection.Type {
|
||||
return connection.TypeWebsocket
|
||||
}
|
||||
|
||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: transport.TLSClientConfig,
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||
if err != nil {
|
||||
|
177
ingress/origin_connection_test.go
Normal file
177
ingress/origin_connection_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
const (
|
||||
testStreamTimeout = time.Second * 3
|
||||
)
|
||||
|
||||
var (
|
||||
testLogger = logger.Create(nil)
|
||||
testMessage = []byte("TestStreamOriginConnection")
|
||||
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
||||
)
|
||||
|
||||
func TestStreamTCPConnection(t *testing.T) {
|
||||
cfdConn, originConn := net.Pipe()
|
||||
tcpConn := tcpConnection{
|
||||
conn: cfdConn,
|
||||
}
|
||||
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
_, err := eyeballConn.Write(testMessage)
|
||||
|
||||
readBuffer := make([]byte, len(testResponse))
|
||||
_, err = eyeballConn.Read(readBuffer)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testResponse, readBuffer)
|
||||
|
||||
return nil
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
echoTCPOrigin(t, originConn)
|
||||
originConn.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
tcpConn.Stream(ctx, edgeConn, testLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
func TestStreamWSOverTCPConnection(t *testing.T) {
|
||||
cfdConn, originConn := net.Pipe()
|
||||
tcpOverWSConn := tcpOverWSConnection{
|
||||
conn: cfdConn,
|
||||
streamHandler: DefaultStreamHandler,
|
||||
}
|
||||
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
echoWSEyeball(t, eyeballConn)
|
||||
return nil
|
||||
})
|
||||
errGroup.Go(func() error {
|
||||
echoTCPOrigin(t, originConn)
|
||||
originConn.Close()
|
||||
return nil
|
||||
})
|
||||
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
func TestStreamWSConnection(t *testing.T) {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
origin := echoWSOrigin(t)
|
||||
defer origin.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
clientTLSConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
||||
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
||||
require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
echoWSEyeball(t, eyeballConn)
|
||||
return nil
|
||||
})
|
||||
|
||||
wsConn.Stream(ctx, edgeConn, testLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
||||
|
||||
readMsg, err := wsutil.ReadServerBinary(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, testResponse, readMsg)
|
||||
|
||||
require.NoError(t, conn.Close())
|
||||
}
|
||||
|
||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 10,
|
||||
WriteBufferSize: 10,
|
||||
}
|
||||
|
||||
ws := func(w http.ResponseWriter, r *http.Request) {
|
||||
header := make(http.Header)
|
||||
for k, vs := range r.Header {
|
||||
if k == "Test-Cloudflared-Echo" {
|
||||
header[k] = vs
|
||||
}
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, header)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
messageType, p, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
require.Equal(t, testMessage, p)
|
||||
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSServer starts the server in another thread
|
||||
return httptest.NewTLSServer(http.HandlerFunc(ws))
|
||||
}
|
||||
|
||||
func echoTCPOrigin(t *testing.T, conn net.Conn) {
|
||||
readBuffer := make([]byte, len(testMessage))
|
||||
_, err := conn.Read(readBuffer)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testMessage, readBuffer)
|
||||
|
||||
_, err = conn.Write(testResponse)
|
||||
assert.NoError(t, err)
|
||||
}
|
@@ -7,19 +7,22 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
type HTTPOriginProxy interface {
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
|
||||
}
|
||||
@@ -28,11 +31,6 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// TODO: TUN-3636: establish connection to origins over UDS
|
||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
return nil, nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||
}
|
||||
|
||||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.url.Host
|
||||
@@ -51,7 +49,7 @@ func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection,
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport, req)
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
@@ -64,20 +62,32 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport, req)
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := o.destination(r)
|
||||
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
dest, err := getRequestHost(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := o.client.connect(r, dest)
|
||||
return conn, nil, err
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
originConn := &tcpConnection{
|
||||
conn: conn,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
}
|
||||
|
||||
// getRequestHost returns the host of the http.Request.
|
||||
@@ -91,10 +101,35 @@ func getRequestHost(r *http.Request) (string, error) {
|
||||
return "", errors.New("host not found")
|
||||
}
|
||||
|
||||
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||
if connection.IsTCPStream(r) {
|
||||
return getRequestHost(r)
|
||||
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
var err error
|
||||
dest := o.dest
|
||||
if o.isBastion {
|
||||
dest, err = o.bastionDest(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.Dial("tcp", dest)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
originConn := &tcpOverWSConnection{
|
||||
conn: conn,
|
||||
streamHandler: o.streamHandler,
|
||||
}
|
||||
resp := &http.Response{
|
||||
Status: switchingProtocolText,
|
||||
StatusCode: http.StatusSwitchingProtocols,
|
||||
Header: websocket.NewResponseHeader(r),
|
||||
ContentLength: -1,
|
||||
}
|
||||
return originConn, resp, nil
|
||||
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
|
||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||
if jumpDestination == "" {
|
||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||
@@ -110,24 +145,3 @@ func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||
func removePath(dest string) string {
|
||||
return strings.SplitN(dest, "/", 2)[0]
|
||||
}
|
||||
|
||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
conn, err := o.client.connect(r, o.dest)
|
||||
return conn, nil, err
|
||||
|
||||
}
|
||||
|
||||
type tcpClient struct {
|
||||
streamHandler streamHandlerFunc
|
||||
}
|
||||
|
||||
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpConnection{
|
||||
conn: conn,
|
||||
streamHandler: c.streamHandler,
|
||||
}, nil
|
||||
}
|
||||
|
@@ -2,6 +2,9 @@ package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -10,12 +13,168 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBridgeServiceDestination(t *testing.T) {
|
||||
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
|
||||
// the expected response
|
||||
func assertEstablishConnectionResponse(t *testing.T,
|
||||
originProxy StreamBasedOriginProxy,
|
||||
req *http.Request,
|
||||
expectHeader http.Header,
|
||||
) {
|
||||
_, resp, err := originProxy.EstablishConnection(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, switchingProtocolText, resp.Status)
|
||||
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
assert.Equal(t, expectHeader, resp.Header)
|
||||
}
|
||||
|
||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||
origin := echoWSOrigin(t)
|
||||
defer origin.Close()
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
httpService := &httpService{
|
||||
url: originURL,
|
||||
hostHeader: origin.URL,
|
||||
transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Test-Cloudflared-Echo", t.Name())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
"Test-Cloudflared-Echo": {t.Name()},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, httpService, req, expectHeader)
|
||||
}
|
||||
|
||||
func TestHelloWorldEstablishConnection(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
helloWorldSerivce := &helloWorld{}
|
||||
helloWorldSerivce.start(&wg, testLogger, shutdownC, errC, OriginRequestConfig{})
|
||||
|
||||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
// Accept key when Sec-Websocket-Key is not specified
|
||||
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||
|
||||
close(shutdownC)
|
||||
}
|
||||
|
||||
func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listenerClosed := make(chan struct{})
|
||||
tcpListenRoutine(originListener, listenerClosed)
|
||||
|
||||
rawTCPService := &rawTCPService{name: ServiceWarpRouting}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
|
||||
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := rawTCPService.EstablishConnection(req)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, resp)
|
||||
}
|
||||
|
||||
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
originListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listenerClosed := make(chan struct{})
|
||||
tcpListenRoutine(originListener, listenerClosed)
|
||||
|
||||
originURL := &url.URL{
|
||||
Scheme: "tcp",
|
||||
Host: originListener.Addr().String(),
|
||||
}
|
||||
|
||||
baseReq, err := http.NewRequest(http.MethodGet, "https://place-holder", nil)
|
||||
require.NoError(t, err)
|
||||
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
bastionReq := baseReq.Clone(context.Background())
|
||||
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
service *tcpOverWSService
|
||||
req *http.Request
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
service: newTCPOverWSService(originURL),
|
||||
req: baseReq,
|
||||
},
|
||||
{
|
||||
service: newBastionService(),
|
||||
req: bastionReq,
|
||||
},
|
||||
{
|
||||
service: newBastionService(),
|
||||
req: baseReq,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if test.expectErr {
|
||||
_, resp, err := test.service.EstablishConnection(test.req)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||
}
|
||||
}
|
||||
|
||||
originListener.Close()
|
||||
<-listenerClosed
|
||||
|
||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, resp, err := service.EstablishConnection(bastionReq)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBastionDestination(t *testing.T) {
|
||||
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -98,12 +257,12 @@ func TestBridgeServiceDestination(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
s := newBridgeService(nil, ServiceBastion)
|
||||
s := newBastionService()
|
||||
for _, test := range tests {
|
||||
r := &http.Request{
|
||||
Header: test.header,
|
||||
}
|
||||
dest, err := s.destination(r)
|
||||
dest, err := s.bastionDest(r)
|
||||
if test.wantErr {
|
||||
assert.Error(t, err, "Test %s expects error", test.name)
|
||||
} else {
|
||||
@@ -139,10 +298,9 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
url: originURL,
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
log := zerolog.Nop()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
require.NoError(t, httpService.start(&wg, &log, shutdownC, errC, cfg))
|
||||
require.NoError(t, httpService.start(&wg, testLogger, shutdownC, errC, cfg))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
@@ -156,3 +314,17 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
}
|
||||
|
||||
func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
close(closeChan)
|
||||
return
|
||||
}
|
||||
// Close immediately, this test is not about testing read/write on connection
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
@@ -78,46 +78,29 @@ func (o *httpService) String() string {
|
||||
return o.url.String()
|
||||
}
|
||||
|
||||
// bridgeService is like a jump host, the destination is specified by the client
|
||||
type bridgeService struct {
|
||||
client *tcpClient
|
||||
serviceName string
|
||||
// rawTCPService dials TCP to the destination specified by the client
|
||||
// It's used by warp routing
|
||||
type rawTCPService struct {
|
||||
name string
|
||||
}
|
||||
|
||||
// if streamHandler is nil, a default one is set.
|
||||
func newBridgeService(streamHandler streamHandlerFunc, serviceName string) *bridgeService {
|
||||
return &bridgeService{
|
||||
client: &tcpClient{
|
||||
streamHandler: streamHandler,
|
||||
},
|
||||
serviceName: serviceName,
|
||||
}
|
||||
func (o *rawTCPService) String() string {
|
||||
return o.name
|
||||
}
|
||||
|
||||
func (o *bridgeService) String() string {
|
||||
return ServiceBridge + ":" + o.serviceName
|
||||
}
|
||||
|
||||
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
// streamHandler is already set by the constructor.
|
||||
if o.client.streamHandler != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = DefaultStreamHandler
|
||||
}
|
||||
func (o *rawTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type singleTCPService struct {
|
||||
dest string
|
||||
client *tcpClient
|
||||
// tcpOverWSService models TCP origins serving eyeballs connecting over websocket, such as
|
||||
// cloudflared access commands.
|
||||
type tcpOverWSService struct {
|
||||
dest string
|
||||
isBastion bool
|
||||
streamHandler streamHandlerFunc
|
||||
}
|
||||
|
||||
func newSingleTCPService(url *url.URL) *singleTCPService {
|
||||
func newTCPOverWSService(url *url.URL) *tcpOverWSService {
|
||||
switch url.Scheme {
|
||||
case "ssh":
|
||||
addPortIfMissing(url, 22)
|
||||
@@ -128,9 +111,14 @@ func newSingleTCPService(url *url.URL) *singleTCPService {
|
||||
case "tcp":
|
||||
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
|
||||
}
|
||||
return &singleTCPService{
|
||||
dest: url.Host,
|
||||
client: &tcpClient{},
|
||||
return &tcpOverWSService{
|
||||
dest: url.Host,
|
||||
}
|
||||
}
|
||||
|
||||
func newBastionService() *tcpOverWSService {
|
||||
return &tcpOverWSService{
|
||||
isBastion: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,15 +128,18 @@ func addPortIfMissing(uri *url.URL, port int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (o *singleTCPService) String() string {
|
||||
func (o *tcpOverWSService) String() string {
|
||||
if o.isBastion {
|
||||
return ServiceBastion
|
||||
}
|
||||
return o.dest
|
||||
}
|
||||
|
||||
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *tcpOverWSService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
o.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = DefaultStreamHandler
|
||||
o.streamHandler = DefaultStreamHandler
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user