TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).

This commit is contained in:
Igor Postelnik
2021-04-02 01:10:43 -05:00
parent b25d38dd72
commit 3ad99b241c
12 changed files with 455 additions and 315 deletions

View File

@@ -3,116 +3,47 @@ package websocket
import (
"crypto/sha1"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/rs/zerolog"
)
var stripWebsocketHeaders = []string{
"Upgrade",
"Connection",
"Sec-Websocket-Key",
"Sec-Websocket-Version",
"Sec-Websocket-Extensions",
}
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = ChangeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req)
if dialler == nil {
dialler = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, response, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, nil
}
// NewResponseHeader returns headers needed to return to origin for completing handshake
func NewResponseHeader(req *http.Request) http.Header {
header := http.Header{}
header.Add("Connection", "Upgrade")
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
header.Add("Sec-Websocket-Accept", generateAcceptKey(req.Header.Get("Sec-WebSocket-Key")))
header.Add("Upgrade", "websocket")
return header
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
func websocketHeaders(req *http.Request) http.Header {
wsHeaders := make(http.Header)
for key, val := range req.Header {
wsHeaders[key] = val
}
// Assume the header keys are in canonical format.
for _, header := range stripWebsocketHeaders {
wsHeaders.Del(header)
}
wsHeaders.Set("Host", req.Host) // See TUN-1097
return wsHeaders
}
// sha1Base64 sha1 and then base64 encodes str.
func sha1Base64(str string) string {
hasher := sha1.New()
_, _ = io.WriteString(hasher, str)
hash := hasher.Sum(nil)
return base64.StdEncoding.EncodeToString(hash)
}
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
func generateAcceptKey(req *http.Request) string {
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
}
// ChangeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
func ChangeRequestScheme(reqURL *url.URL) string {
switch reqURL.Scheme {
case "https":
return "wss"
case "http":
return "ws"
case "":
return "ws"
default:
return reqURL.Scheme
}
}
// Stream copies copy data to & from provided io.ReadWriters.
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
func Stream(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
proxyDone := make(chan struct{}, 2)
go func() {
_, err := io.Copy(conn, backendConn)
_, err := copyData(tunnelConn, originConn, "origin->tunnel")
if err != nil {
log.Debug().Msgf("conn to backendConn copy: %v", err)
log.Debug().Msgf("origin to tunnel copy: %v", err)
}
proxyDone <- struct{}{}
}()
go func() {
_, err := io.Copy(backendConn, conn)
_, err := copyData(originConn, tunnelConn, "tunnel->origin")
if err != nil {
log.Debug().Msgf("backendConn to conn copy: %v", err)
log.Debug().Msgf("tunnel to origin copy: %v", err)
}
proxyDone <- struct{}{}
}()
@@ -120,3 +51,60 @@ func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
// If one side is done, we are done.
<-proxyDone
}
// when set to true, enables logging of content copied to/from origin and tunnel
const debugCopy = false
func copyData(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
if debugCopy {
// copyBuffer is based on stdio Copy implementation but shows copied data
copyBuffer := func(dst io.Writer, src io.Reader, dir string) (written int64, err error) {
var buf []byte
size := 32 * 1024
buf = make([]byte, size)
for {
t := time.Now()
nr, er := src.Read(buf)
if nr > 0 {
fmt.Println(dir, t.UnixNano(), "\n"+hex.Dump(buf[0:nr]))
nw, ew := dst.Write(buf[0:nr])
if nw < 0 || nr < nw {
nw = 0
if ew == nil {
ew = errors.New("invalid write")
}
}
written += int64(nw)
if ew != nil {
err = ew
break
}
if nr != nw {
err = io.ErrShortWrite
break
}
}
if er != nil {
if er != io.EOF {
err = er
}
break
}
}
return written, err
}
return copyBuffer(dst, src, dir)
} else {
return io.Copy(dst, src)
}
}
// from RFC-6455
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func generateAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}

View File

@@ -1,24 +1,9 @@
package websocket
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"math/rand"
"net/http"
"testing"
"time"
gws "github.com/gorilla/websocket"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
)
const (
@@ -28,126 +13,6 @@ const (
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
)
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
req, err := http.NewRequest("GET", url, stream)
if err != nil {
t.Fatalf("testRequestHeader error")
}
req.Header.Add("Connection", "Upgrade")
req.Header.Add("Upgrade", "WebSocket")
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
req.Header.Add("Sec-Websocket-Version", "13")
req.Header.Add("User-Agent", "curl/7.59.0")
return req
}
func websocketClientTLSConfig(t *testing.T) *tls.Config {
certPool := x509.NewCertPool()
helloCert, err := tlsconfig.GetHelloCertificateX509()
assert.NoError(t, err)
certPool.AddCert(helloCert)
assert.NotNil(t, certPool)
return &tls.Config{RootCAs: certPool}
}
func TestWebsocketHeaders(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
wsHeaders := websocketHeaders(req)
for _, header := range stripWebsocketHeaders {
assert.Empty(t, wsHeaders[header])
}
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
}
func TestGenerateAcceptKey(t *testing.T) {
req := testRequest(t, "http://example.com", nil)
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
}
func TestServe(t *testing.T) {
log := zerolog.Nop()
shutdownC := make(chan struct{})
errC := make(chan error)
listener, err := hello.CreateTLSListener("localhost:1111")
assert.NoError(t, err)
defer listener.Close()
go func() {
errC <- hello.StartHelloWorldServer(&log, listener, shutdownC)
}()
req := testRequest(t, "https://localhost:1111/ws", nil)
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := ClientConnect(req, &d)
assert.NoError(t, err)
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
for i := 0; i < 1000; i++ {
messageSize := rand.Int()%2048 + 1
clientMessage := make([]byte, messageSize)
// rand.Read always returns len(clientMessage) and a nil error
rand.Read(clientMessage)
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
assert.NoError(t, err)
messageType, message, err := conn.ReadMessage()
assert.NoError(t, err)
assert.Equal(t, websocket.BinaryFrame, messageType)
assert.Equal(t, clientMessage, message)
}
_ = conn.Close()
close(shutdownC)
<-errC
}
func TestWebsocketWrapper(t *testing.T) {
listener, err := hello.CreateTLSListener("localhost:0")
require.NoError(t, err)
serverErrorChan := make(chan error)
helloSvrCtx, cancelHelloSvr := context.WithCancel(context.Background())
defer func() { <-serverErrorChan }()
defer cancelHelloSvr()
go func() {
log := zerolog.Nop()
serverErrorChan <- hello.StartHelloWorldServer(&log, listener, helloSvrCtx.Done())
}()
tlsConfig := websocketClientTLSConfig(t)
d := gws.Dialer{TLSClientConfig: tlsConfig, HandshakeTimeout: time.Minute}
testAddr := fmt.Sprintf("https://%s/ws", listener.Addr().String())
req := testRequest(t, testAddr, nil)
conn, resp, err := ClientConnect(req, &d)
require.NoError(t, err)
require.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
// Websocket now connected to test server so lets check our wrapper
wrapper := GorillaConn{Conn: conn}
buf := make([]byte, 100)
wrapper.Write([]byte("abc"))
n, err := wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 3)
require.Equal(t, "abc", string(buf[:n]))
// Test partial read, read 1 of 3 bytes in one read and the other 2 in another read
wrapper.Write([]byte("abc"))
buf = buf[:1]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 1)
require.Equal(t, "a", string(buf[:n]))
buf = buf[:cap(buf)]
n, err = wrapper.Read(buf)
require.NoError(t, err)
require.Equal(t, n, 2)
require.Equal(t, "bc", string(buf[:n]))
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(testSecWebsocketKey))
}