TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).

This commit is contained in:
Igor Postelnik
2021-04-02 01:10:43 -05:00
parent b25d38dd72
commit 3ad99b241c
12 changed files with 455 additions and 315 deletions

View File

@@ -2,12 +2,9 @@ package ingress
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/ipaccess"
@@ -58,35 +55,6 @@ func (wc *tcpOverWSConnection) Close() {
wc.conn.Close()
}
// wsConnection is an OriginConnection that streams WS between eyeball and origin.
type wsConnection struct {
wsConn *gws.Conn
resp *http.Response
}
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
}
func (wsc *wsConnection) Close() {
wsc.resp.Body.Close()
wsc.wsConn.Close()
}
func newWSConnection(clientTLSConfig *tls.Config, r *http.Request) (OriginConnection, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: clientTLSConfig,
}
wsConn, resp, err := websocket.ClientConnect(r, d)
if err != nil {
return nil, nil, err
}
return &wsConnection{
wsConn,
resp,
}, resp, nil
}
// socksProxyOverWSConnection is an OriginConnection that streams SOCKS connections over WS.
// The connection to the origin happens inside the SOCKS code as the client specifies the origin
// details in the packet.
@@ -100,3 +68,16 @@ func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.
func (sp *socksProxyOverWSConnection) Close() {
}
// wsProxyConnection represents a bidirectional stream for a websocket connection to the origin
type wsProxyConnection struct {
rwc io.ReadWriteCloser
}
func (conn *wsProxyConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
websocket.Stream(tunnelConn, conn.rwc, log)
}
func (conn *wsProxyConnection) Close() {
conn.rwc.Close()
}

View File

@@ -3,13 +3,13 @@ package ingress
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
@@ -193,18 +193,26 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
func TestStreamWSConnection(t *testing.T) {
eyeballConn, edgeConn := net.Pipe()
origin := echoWSOrigin(t)
origin := echoWSOrigin(t, true)
defer origin.Close()
var svc httpService
err := svc.start(&sync.WaitGroup{}, testLogger, nil, nil, OriginRequestConfig{
NoTLSVerify: true,
})
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
conn, resp, err := svc.newWebsocketProxyConnection(req)
clientTLSConfig := &tls.Config{
InsecureSkipVerify: true,
}
wsConn, resp, err := newWSConnection(clientTLSConfig, req)
require.NoError(t, err)
defer conn.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
@@ -213,13 +221,37 @@ func TestStreamWSConnection(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
connClosed := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
select {
case <-connClosed:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
eyeballConn.Close()
edgeConn.Close()
conn.Close()
}
return ctx.Err()
})
errGroup.Go(func() error {
echoWSEyeball(t, eyeballConn)
fmt.Println("closing pipe")
edgeConn.Close()
return eyeballConn.Close()
})
errGroup.Go(func() error {
defer conn.Close()
conn.Stream(ctx, edgeConn, testLogger)
close(connClosed)
return nil
})
wsConn.Stream(ctx, edgeConn, testLogger)
require.NoError(t, errGroup.Wait())
}
@@ -241,17 +273,23 @@ func (wse *wsEyeball) Write(p []byte) (int, error) {
}
func echoWSEyeball(t *testing.T, conn net.Conn) {
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
defer func() {
assert.NoError(t, conn.Close())
}()
if !assert.NoError(t, wsutil.WriteClientBinary(conn, testMessage)) {
return
}
readMsg, err := wsutil.ReadServerBinary(conn)
require.NoError(t, err)
if !assert.NoError(t, err) {
return
}
require.Equal(t, testResponse, readMsg)
require.NoError(t, conn.Close())
assert.Equal(t, testResponse, readMsg)
}
func echoWSOrigin(t *testing.T) *httptest.Server {
func echoWSOrigin(t *testing.T, expectMessages bool) *httptest.Server {
var upgrader = gorillaWS.Upgrader{
ReadBufferSize: 10,
WriteBufferSize: 10,
@@ -268,12 +306,17 @@ func echoWSOrigin(t *testing.T) *httptest.Server {
require.NoError(t, err)
defer conn.Close()
sawMessage := false
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
if expectMessages && !sawMessage {
t.Errorf("unexpected error: %v", err)
}
return
}
require.Equal(t, testMessage, p)
assert.Equal(t, testMessage, p)
sawMessage = true
if err := conn.WriteMessage(messageType, testResponse); err != nil {
return
}

View File

@@ -2,8 +2,10 @@ package ingress
import (
"fmt"
"io"
"net"
"net/http"
"strings"
"github.com/pkg/errors"
@@ -12,7 +14,8 @@ import (
)
var (
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
errUnsupportedConnectionType = errors.New("internal error: unsupported connection type")
)
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
@@ -42,26 +45,64 @@ func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
}
func (o *httpService) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req = req.Clone(req.Context())
req.URL.Host = o.url.Host
req.URL.Scheme = websocket.ChangeRequestScheme(o.url)
req.URL.Scheme = o.url.Scheme
// allow ws(s) scheme for websocket-only origins, normal http(s) requests will fail
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if o.hostHeader != "" {
// For incoming requests, the Host header is promoted to the Request.Host field and removed from the Header map.
req.Host = o.hostHeader
}
return newWSConnection(o.transport.TLSClientConfig, req)
return o.newWebsocketProxyConnection(req)
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *httpService) newWebsocketProxyConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, *http.Response, error) {
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "wss"
return newWSConnection(o.transport.TLSClientConfig, req)
req.ContentLength = 0
req.Body = nil
resp, err := o.transport.RoundTrip(req)
if err != nil {
return nil, nil, err
}
toClose := resp.Body
defer func() {
if toClose != nil {
_ = toClose.Close()
}
}()
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, nil, fmt.Errorf("unexpected origin response: %s", resp.Status)
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" {
return nil, nil, fmt.Errorf("unexpected upgrade: %q", resp.Header.Get("Upgrade"))
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, nil, errUnsupportedConnectionType
}
conn := wsProxyConnection{
rwc: rwc,
}
// clear to prevent defer from closing
toClose = nil
return &conn, resp, nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {

View File

@@ -33,7 +33,7 @@ func assertEstablishConnectionResponse(t *testing.T,
}
func TestHTTPServiceEstablishConnection(t *testing.T) {
origin := echoWSOrigin(t)
origin := echoWSOrigin(t, false)
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
@@ -71,11 +71,11 @@ func TestHelloWorldEstablishConnection(t *testing.T) {
// Scheme and Host of URL will be override by the Scheme and Host of the helloWorld service
req, err := http.NewRequest(http.MethodGet, "https://place-holder/ws", nil)
require.NoError(t, err)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
expectHeader := http.Header{
"Connection": {"Upgrade"},
// Accept key when Sec-Websocket-Key is not specified
"Sec-Websocket-Accept": {"Kfh9QIsMVZcl6xEPYxPHzW8SZ8w="},
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
assertEstablishConnectionResponse(t, helloWorldSerivce, req, expectHeader)

View File

@@ -11,7 +11,6 @@ import (
"sync"
"time"
gws "github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@@ -19,7 +18,6 @@ import (
"github.com/cloudflare/cloudflared/ipaccess"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/websocket"
)
// originService is something a tunnel can proxy traffic to.
@@ -50,16 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
return nil
}
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
NetDial: o.transport.Dial,
NetDialContext: o.transport.DialContext,
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
return d.Dial(reqURL.String(), headers)
}
type httpService struct {
url *url.URL
hostHeader string
@@ -171,8 +159,8 @@ func (o *socksProxyOverWSService) String() string {
// HelloWorld is an OriginService for the built-in Hello World server.
// Users only use this for testing and experimenting with cloudflared.
type helloWorld struct {
server net.Listener
transport *http.Transport
httpService
server net.Listener
}
func (o *helloWorld) String() string {
@@ -187,11 +175,10 @@ func (o *helloWorld) start(
errC chan error,
cfg OriginRequestConfig,
) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
if err := o.httpService.start(wg, log, shutdownC, errC, cfg); err != nil {
return err
}
o.transport = transport
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
@@ -202,6 +189,12 @@ func (o *helloWorld) start(
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
o.server = helloListener
o.httpService.url = &url.URL{
Scheme: "https",
Host: o.server.Addr().String(),
}
return nil
}