mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-26 23:29:57 +00:00
TUN-528: Move cloudflared into a separate repo
This commit is contained in:
117
websocket/websocket.go
Normal file
117
websocket/websocket.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var stripWebsocketHeaders = []string {
|
||||
"Upgrade",
|
||||
"Connection",
|
||||
"Sec-Websocket-Key",
|
||||
"Sec-Websocket-Version",
|
||||
"Sec-Websocket-Extensions",
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return websocket.IsWebSocketUpgrade(req)
|
||||
}
|
||||
|
||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||
// the connection. The response body may not contain the entire response and does
|
||||
// not need to be closed by the application.
|
||||
func ClientConnect(req *http.Request, tlsClientConfig *tls.Config) (*websocket.Conn, *http.Response, error) {
|
||||
req.URL.Scheme = changeRequestScheme(req)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
|
||||
d := &websocket.Dialer{TLSClientConfig: tlsClientConfig}
|
||||
conn, response, err := d.Dial(req.URL.String(), wsHeaders)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||
return conn, response, err
|
||||
}
|
||||
|
||||
// HijackConnection takes over an HTTP connection. Caller is responsible for closing connection.
|
||||
func HijackConnection(w http.ResponseWriter) (net.Conn, *bufio.ReadWriter, error) {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, errors.New("hijack error")
|
||||
}
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, brw, nil
|
||||
}
|
||||
|
||||
// Stream copies copy data to & from provided io.ReadWriters.
|
||||
func Stream(conn, backendConn io.ReadWriter) {
|
||||
proxyDone := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
io.Copy(conn, backendConn)
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
io.Copy(backendConn, conn)
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
// If one side is done, we are done.
|
||||
<-proxyDone
|
||||
}
|
||||
|
||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||
// Sec-WebSocket-Version and Sec-Websocket-Extensions headers.
|
||||
// https://github.com/gorilla/websocket/blob/master/client.go#L189-L194.
|
||||
func websocketHeaders(req *http.Request) http.Header {
|
||||
wsHeaders := make(http.Header)
|
||||
for key, val := range req.Header {
|
||||
wsHeaders[key] = val
|
||||
}
|
||||
// Assume the header keys are in canonical format.
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
wsHeaders.Del(header)
|
||||
}
|
||||
return wsHeaders
|
||||
}
|
||||
|
||||
// sha1Base64 sha1 and then base64 encodes str.
|
||||
func sha1Base64(str string) string {
|
||||
hasher := sha1.New()
|
||||
io.WriteString(hasher, str)
|
||||
hash := hasher.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(hash)
|
||||
}
|
||||
|
||||
// generateAcceptKey returns the string needed for the Sec-WebSocket-Accept header.
|
||||
// https://tools.ietf.org/html/rfc6455#section-1.3 describes this process in more detail.
|
||||
func generateAcceptKey(req *http.Request) string {
|
||||
return sha1Base64(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
}
|
||||
|
||||
// changeRequestScheme is needed as the gorilla websocket library requires the ws scheme.
|
||||
// (even though it changes it back to http/https, but ¯\_(ツ)_/¯.)
|
||||
func changeRequestScheme(req *http.Request) string {
|
||||
switch req.URL.Scheme {
|
||||
case "https":
|
||||
return "wss"
|
||||
case "http":
|
||||
return "ws"
|
||||
default:
|
||||
return req.URL.Scheme
|
||||
}
|
||||
}
|
100
websocket/websocket_test.go
Normal file
100
websocket/websocket_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
// example in Sec-Websocket-Key in rfc6455
|
||||
testSecWebsocketKey = "dGhlIHNhbXBsZSBub25jZQ=="
|
||||
// example Sec-Websocket-Accept in rfc6455
|
||||
testSecWebsocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
)
|
||||
|
||||
func testRequest(t *testing.T, url string, stream io.ReadWriter) *http.Request {
|
||||
req, err := http.NewRequest("GET", url, stream)
|
||||
if err != nil {
|
||||
t.Fatalf("testRequestHeader error")
|
||||
}
|
||||
|
||||
req.Header.Add("Connection", "Upgrade")
|
||||
req.Header.Add("Upgrade", "WebSocket")
|
||||
req.Header.Add("Sec-Websocket-Key", testSecWebsocketKey)
|
||||
req.Header.Add("Sec-Websocket-Protocol", "tunnel-protocol")
|
||||
req.Header.Add("Sec-Websocket-Version", "13")
|
||||
req.Header.Add("User-Agent", "curl/7.59.0")
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
func websocketClientTLSConfig(t *testing.T) *tls.Config {
|
||||
certPool, err := tlsconfig.LoadOriginCertPool(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, certPool)
|
||||
return &tls.Config{RootCAs: certPool}
|
||||
}
|
||||
|
||||
func TestWebsocketHeaders(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
for _, header := range stripWebsocketHeaders {
|
||||
assert.Empty(t, wsHeaders[header])
|
||||
}
|
||||
assert.Equal(t, "curl/7.59.0", wsHeaders.Get("User-Agent"))
|
||||
}
|
||||
|
||||
func TestGenerateAcceptKey(t *testing.T) {
|
||||
req := testRequest(t, "http://example.com", nil)
|
||||
assert.Equal(t, testSecWebsocketAccept, generateAcceptKey(req))
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
logger := logrus.New()
|
||||
shutdownC := make(chan struct{})
|
||||
errC := make(chan error)
|
||||
listener, err := hello.CreateTLSListener("localhost:1111")
|
||||
assert.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
errC <- hello.StartHelloWorldServer(logger, listener, shutdownC)
|
||||
}()
|
||||
|
||||
req := testRequest(t, "https://localhost:1111/ws", nil)
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
conn, resp, err := ClientConnect(req, tlsConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
messageSize := rand.Int() % 2048 + 1
|
||||
clientMessage := make([]byte, messageSize)
|
||||
// rand.Read always returns len(clientMessage) and a nil error
|
||||
rand.Read(clientMessage)
|
||||
err = conn.WriteMessage(websocket.BinaryFrame, clientMessage)
|
||||
assert.NoError(t, err)
|
||||
|
||||
messageType, message, err := conn.ReadMessage()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, websocket.BinaryFrame, messageType)
|
||||
assert.Equal(t, clientMessage, message)
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
close(shutdownC)
|
||||
<-errC
|
||||
}
|
Reference in New Issue
Block a user