mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:49:58 +00:00
TUN-3617: Separate service from client, and implement different client for http vs. tcp origins
- extracted ResponseWriter from proxyConnection - added bastion tests over websocket - removed HTTPResp() - added some docstrings - Renamed some ingress clients as proxies - renamed instances of client to proxy in connection and origin - Stream no longer takes a context and logger.Service
This commit is contained in:
118
websocket/connection.go
Normal file
118
websocket/connection.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/rs/zerolog"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
gobwas "github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
)
|
||||
|
||||
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
|
||||
// This is still used by access carrier
|
||||
type GorillaConn struct {
|
||||
*websocket.Conn
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
// Read will read messages from the websocket connection
|
||||
func (c *GorillaConn) Read(p []byte) (int, error) {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return copy(p, message), nil
|
||||
|
||||
}
|
||||
|
||||
// Write will write messages to the websocket connection
|
||||
func (c *GorillaConn) Write(p []byte) (int, error) {
|
||||
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// pinger simulates the websocket connection to keep it alive
|
||||
func (c *GorillaConn) pinger(ctx context.Context) {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
||||
c.log.Debug().Msgf("failed to send ping message: %s", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
rw io.ReadWriter
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn {
|
||||
return &Conn{
|
||||
rw: rw,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// Read will read messages from the websocket connection
|
||||
func (c *Conn) Read(reader []byte) (int, error) {
|
||||
data, err := wsutil.ReadClientBinary(c.rw)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return copy(reader, data), nil
|
||||
}
|
||||
|
||||
// Write will write messages to the websocket connection
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *Conn) Pinger(ctx context.Context) {
|
||||
pongMessge := wsutil.Message{
|
||||
OpCode: gobwas.OpPong,
|
||||
Payload: []byte{},
|
||||
}
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
|
||||
c.log.Err(err).Msgf("failed to write ping message")
|
||||
}
|
||||
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
|
||||
c.log.Err(err).Msgf("failed to write pong message")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
@@ -2,7 +2,6 @@ package websocket
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"net"
|
||||
@@ -16,17 +15,6 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
const (
|
||||
// Time allowed to write a message to the peer.
|
||||
writeWait = 10 * time.Second
|
||||
|
||||
// Time allowed to read the next pong message from the peer.
|
||||
pongWait = 60 * time.Second
|
||||
|
||||
// Send pings to peer with this period. Must be less than pongWait.
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
)
|
||||
|
||||
var stripWebsocketHeaders = []string{
|
||||
"Upgrade",
|
||||
"Connection",
|
||||
@@ -35,70 +23,28 @@ var stripWebsocketHeaders = []string{
|
||||
"Sec-Websocket-Extensions",
|
||||
}
|
||||
|
||||
// Conn is a wrapper around the standard gorilla websocket
|
||||
// but implements a ReadWriter
|
||||
type Conn struct {
|
||||
*websocket.Conn
|
||||
}
|
||||
|
||||
// Read will read messages from the websocket connection
|
||||
func (c *Conn) Read(p []byte) (int, error) {
|
||||
_, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return copy(p, message), nil
|
||||
|
||||
}
|
||||
|
||||
// Write will write messages to the websocket connection
|
||||
func (c *Conn) Write(p []byte) (int, error) {
|
||||
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
|
||||
func IsWebSocketUpgrade(req *http.Request) bool {
|
||||
return websocket.IsWebSocketUpgrade(req)
|
||||
}
|
||||
|
||||
// Dialler is something that can proxy websocket requests.
|
||||
type Dialler interface {
|
||||
Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error)
|
||||
}
|
||||
|
||||
type defaultDialler struct {
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) {
|
||||
d := &websocket.Dialer{
|
||||
TLSClientConfig: dd.tlsConfig,
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
return d.Dial(url.String(), header)
|
||||
}
|
||||
|
||||
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
|
||||
// the connection. The response body may not contain the entire response and does
|
||||
// not need to be closed by the application.
|
||||
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
|
||||
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
|
||||
req.URL.Scheme = ChangeRequestScheme(req.URL)
|
||||
wsHeaders := websocketHeaders(req)
|
||||
|
||||
if dialler == nil {
|
||||
dialler = new(defaultDialler)
|
||||
dialler = &websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
}
|
||||
conn, response, err := dialler.Dial(req.URL, wsHeaders)
|
||||
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
|
||||
if err != nil {
|
||||
return nil, response, err
|
||||
}
|
||||
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
|
||||
return conn, response, err
|
||||
return conn, response, nil
|
||||
}
|
||||
|
||||
// Stream copies copy data to & from provided io.ReadWriters.
|
||||
@@ -121,8 +67,8 @@ func Stream(conn, backendConn io.ReadWriter) {
|
||||
|
||||
// DefaultStreamHandler is provided to the the standard websocket to origin stream
|
||||
// This exist to allow SOCKS to deframe data before it gets to the origin
|
||||
func DefaultStreamHandler(wsConn *Conn, remoteConn net.Conn, _ http.Header) {
|
||||
Stream(wsConn, remoteConn)
|
||||
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
|
||||
Stream(originConn, remoteConn)
|
||||
}
|
||||
|
||||
// StartProxyServer will start a websocket server that will decode
|
||||
@@ -132,7 +78,7 @@ func StartProxyServer(
|
||||
listener net.Listener,
|
||||
staticHost string,
|
||||
shutdownC <-chan struct{},
|
||||
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header),
|
||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn),
|
||||
) error {
|
||||
upgrader := websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
@@ -159,7 +105,7 @@ type handler struct {
|
||||
log *zerolog.Logger
|
||||
staticHost string
|
||||
upgrader websocket.Upgrader
|
||||
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header)
|
||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -192,14 +138,20 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
|
||||
done := make(chan struct{})
|
||||
go pinger(h.log, conn, done)
|
||||
defer func() {
|
||||
done <- struct{}{}
|
||||
_ = conn.Close()
|
||||
}()
|
||||
gorillaConn := &GorillaConn{conn, h.log}
|
||||
go gorillaConn.pinger(r.Context())
|
||||
defer conn.Close()
|
||||
|
||||
h.streamHandler(&Conn{conn}, stream, r.Header)
|
||||
h.streamHandler(gorillaConn, stream)
|
||||
}
|
||||
|
||||
// NewResponseHeader returns headers needed to return to origin for completing handshake
|
||||
func NewResponseHeader(req *http.Request) http.Header {
|
||||
header := http.Header{}
|
||||
header.Add("Connection", "Upgrade")
|
||||
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
|
||||
header.Add("Upgrade", "websocket")
|
||||
return header
|
||||
}
|
||||
|
||||
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
|
||||
@@ -246,19 +198,3 @@ func ChangeRequestScheme(reqURL *url.URL) string {
|
||||
return reqURL.Scheme
|
||||
}
|
||||
}
|
||||
|
||||
// pinger simulates the websocket connection to keep it alive
|
||||
func pinger(logger *zerolog.Logger, ws *websocket.Conn, done chan struct{}) {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
|
||||
logger.Debug().Msgf("failed to send ping message: %s", err)
|
||||
}
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
@@ -78,7 +78,7 @@ func TestServe(t *testing.T) {
|
||||
|
||||
tlsConfig := websocketClientTLSConfig(t)
|
||||
assert.NotNil(t, tlsConfig)
|
||||
d := defaultDialler{tlsConfig: tlsConfig}
|
||||
d := gws.Dialer{TLSClientConfig: tlsConfig}
|
||||
conn, resp, err := ClientConnect(req, &d)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
|
Reference in New Issue
Block a user