mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:19:57 +00:00
TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).
This commit is contained in:
@@ -2,12 +2,9 @@ package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/cloudflare/cloudflared/ipaccess"
|
||||
@@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() {
|
||||
wc.conn.Close()
|
||||
}
|
||||
|
||||
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
|
||||
type wsConnection struct {
|
||||
wsConn *gws.Conn
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Close() {
|
||||
wsc.resp.Body.Close()
|
||||
wsc.wsConn.Close()
|
||||
}
|
||||
|
||||
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: clientTLSConfig,
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &wsConnection{
|
||||
wsConn,
|
||||
resp,
|
||||
}, resp, nil
|
||||
}
|
||||
|
||||
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
|
||||
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
|
||||
// details in the packet.
|
||||
@@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
|
||||
|
||||
func (sp *socksProxyOverWSConnection) Close() {
|
||||
}
|
||||
|
||||
// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin
|
||||
type wsProxyConnection struct {
|
||||
rwc io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
websocket.Stream(tunnelConn, conn.rwc, log)
|
||||
}
|
||||
|
||||
func (conn *wsProxyConnection) Close() {
|
||||
conn.rwc.Close()
|
||||
}
|
||||
|
@@ -3,13 +3,13 @@ package ingress
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
func TestStreamWSConnection(t *testing.T) {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
|
||||
origin := echoWSOrigin(t)
|
||||
origin := echoWSOrigin(t, true)
|
||||
defer origin.Close()
|
||||
|
||||
var svc httpService
|
||||
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
|
||||
NoTLSVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
|
||||
conn, resp, err := svc.newWebsocketProxyConnection(req)
|
||||
|
||||
clientTLSConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
|
||||
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
|
||||
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
|
||||
@@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
|
||||
defer cancel()
|
||||
|
||||
connClosed := make(chan struct{})
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case <-connClosed:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
eyeballConn.Close()
|
||||
edgeConn.Close()
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
return ctx.Err()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
echoWSEyeball(t, eyeballConn)
|
||||
fmt.Println("closing pipe")
|
||||
edgeConn.Close()
|
||||
return eyeballConn.Close()
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
defer conn.Close()
|
||||
conn.Stream(ctx, edgeConn, testLogger)
|
||||
close(connClosed)
|
||||
return nil
|
||||
})
|
||||
|
||||
wsConn.Stream(ctx, edgeConn, testLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
@@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
func echoWSEyeball(t *testing.T, conn net.Conn) {
|
||||
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
|
||||
defer func() {
|
||||
assert.NoError(t, conn.Close())
|
||||
}()
|
||||
|
||||
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
|
||||
return
|
||||
}
|
||||
|
||||
readMsg, err := wsutil.ReadServerBinary(conn)
|
||||
require.NoError(t, err)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, testResponse, readMsg)
|
||||
|
||||
require.NoError(t, conn.Close())
|
||||
assert.Equal(t, testResponse, readMsg)
|
||||
}
|
||||
|
||||
func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
|
||||
var upgrader = gorillaWS.Upgrader{
|
||||
ReadBufferSize: 10,
|
||||
WriteBufferSize: 10,
|
||||
@@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server {
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
sawMessage := false
|
||||
for {
|
||||
messageType, p, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if expectMessages && !sawMessage {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
require.Equal(t, testMessage, p)
|
||||
assert.Equal(t, testMessage, p)
|
||||
sawMessage = true
|
||||
if err := conn.WriteMessage(messageType, testResponse); err != nil {
|
||||
return
|
||||
}
|
||||
|
@@ -2,8 +2,10 @@ package ingress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
@@ -12,7 +14,8 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
|
||||
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
@@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
|
||||
switch req.URL.Scheme {
|
||||
case "ws":
|
||||
req.URL.Scheme = "http"
|
||||
case "wss":
|
||||
req.URL.Scheme = "https"
|
||||
}
|
||||
|
||||
if o.hostHeader != "" {
|
||||
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
|
||||
req.Host = o.hostHeader
|
||||
}
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
|
||||
return o.newWebsocketProxyConnection(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport.TLSClientConfig, req)
|
||||
req.ContentLength = 0
|
||||
req.Body = nil
|
||||
|
||||
resp, err := o.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
toClose := resp.Body
|
||||
defer func() {
|
||||
if toClose != nil {
|
||||
_ = toClose.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
|
||||
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
|
||||
}
|
||||
|
||||
rwc, ok := resp.Body.(io.ReadWriteCloser)
|
||||
if !ok {
|
||||
return nil, nil, errUnsupportedConnectionType
|
||||
}
|
||||
conn := wsProxyConnection{
|
||||
rwc: rwc,
|
||||
}
|
||||
// clear to prevent defer from closing
|
||||
toClose = nil
|
||||
|
||||
return &conn, resp, nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
|
@@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T,
|
||||
}
|
||||
|
||||
func TestHTTPServiceEstablishConnection(t *testing.T) {
|
||||
origin := echoWSOrigin(t)
|
||||
origin := echoWSOrigin(t, false)
|
||||
defer origin.Close()
|
||||
originURL, err := url.Parse(origin.URL)
|
||||
require.NoError(t, err)
|
||||
@@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) {
|
||||
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
|
||||
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
// Accept key when Sec-Websocket-Key is not specified
|
||||
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
|
||||
"Connection": {"Upgrade"},
|
||||
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
|
||||
"Upgrade": {"websocket"},
|
||||
}
|
||||
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)
|
||||
|
@@ -11,7 +11,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
"github.com/cloudflare/cloudflared/ipaccess"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
// originService is something a tunnel can proxy traffic to.
|
||||
@@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
NetDial: o.transport.Dial,
|
||||
NetDialContext: o.transport.DialContext,
|
||||
TLSClientConfig: o.transport.TLSClientConfig,
|
||||
}
|
||||
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
hostHeader string
|
||||
@@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string {
|
||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||
// Users only use this for testing and experimenting with cloudflared.
|
||||
type helloWorld struct {
|
||||
server net.Listener
|
||||
transport *http.Transport
|
||||
httpService
|
||||
server net.Listener
|
||||
}
|
||||
|
||||
func (o *helloWorld) String() string {
|
||||
@@ -187,11 +175,10 @@ func (o *helloWorld) start(
|
||||
errC chan error,
|
||||
cfg OriginRequestConfig,
|
||||
) error {
|
||||
transport, err := newHTTPTransport(o, cfg, log)
|
||||
if err != nil {
|
||||
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
|
||||
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Cannot start Hello World Server")
|
||||
@@ -202,6 +189,12 @@ func (o *helloWorld) start(
|
||||
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
|
||||
}()
|
||||
o.server = helloListener
|
||||
|
||||
o.httpService.url = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: o.server.Addr().String(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user