TUN-3617: Separate service from client, and implement different client for http vs. tcp origins

- extracted ResponseWriter from proxyConnection
 - added bastion tests over websocket
 - removed HTTPResp()
 - added some docstrings
 - Renamed some ingress clients as proxies
 - renamed instances of client to proxy in connection and origin
 - Stream no longer takes a context and logger.Service
This commit is contained in:
cthuang
2020-12-09 21:46:53 +00:00
committed by Nuno Diegues
parent 5e2b43adb5
commit e2262085e5
23 changed files with 839 additions and 354 deletions

118
websocket/connection.go Normal file
View File

@@ -0,0 +1,118 @@
package websocket
import (
"context"
"github.com/rs/zerolog"
"io"
"time"
gobwas "github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/gorilla/websocket"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
)
// GorillaConn is a wrapper around the standard gorilla websocket but implements a ReadWriter
// This is still used by access carrier
type GorillaConn struct {
*websocket.Conn
log *zerolog.Logger
}
// Read will read messages from the websocket connection
func (c *GorillaConn) Read(p []byte) (int, error) {
_, message, err := c.Conn.ReadMessage()
if err != nil {
return 0, err
}
return copy(p, message), nil
}
// Write will write messages to the websocket connection
func (c *GorillaConn) Write(p []byte) (int, error) {
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
return 0, err
}
return len(p), nil
}
// pinger simulates the websocket connection to keep it alive
func (c *GorillaConn) pinger(ctx context.Context) {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
c.log.Debug().Msgf("failed to send ping message: %s", err)
}
case <-ctx.Done():
return
}
}
}
type Conn struct {
rw io.ReadWriter
log *zerolog.Logger
}
func NewConn(rw io.ReadWriter, log *zerolog.Logger) *Conn {
return &Conn{
rw: rw,
log: log,
}
}
// Read will read messages from the websocket connection
func (c *Conn) Read(reader []byte) (int, error) {
data, err := wsutil.ReadClientBinary(c.rw)
if err != nil {
return 0, err
}
return copy(reader, data), nil
}
// Write will write messages to the websocket connection
func (c *Conn) Write(p []byte) (int, error) {
if err := wsutil.WriteServerBinary(c.rw, p); err != nil {
return 0, err
}
return len(p), nil
}
func (c *Conn) Pinger(ctx context.Context) {
pongMessge := wsutil.Message{
OpCode: gobwas.OpPong,
Payload: []byte{},
}
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := wsutil.WriteServerMessage(c.rw, gobwas.OpPing, []byte{}); err != nil {
c.log.Err(err).Msgf("failed to write ping message")
}
if err := wsutil.HandleClientControlMessage(c.rw, pongMessge); err != nil {
c.log.Err(err).Msgf("failed to write pong message")
}
case <-ctx.Done():
return
}
}
}

View File

@@ -2,7 +2,6 @@ package websocket
import (
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"io"
"net"
@@ -16,17 +15,6 @@ import (
"github.com/rs/zerolog"
)
const (
// Time allowed to write a message to the peer.
writeWait = 10 * time.Second
// Time allowed to read the next pong message from the peer.
pongWait = 60 * time.Second
// Send pings to peer with this period. Must be less than pongWait.
pingPeriod = (pongWait * 9) / 10
)
var stripWebsocketHeaders = []string{
"Upgrade",
"Connection",
@@ -35,70 +23,28 @@ var stripWebsocketHeaders = []string{
"Sec-Websocket-Extensions",
}
// Conn is a wrapper around the standard gorilla websocket
// but implements a ReadWriter
type Conn struct {
*websocket.Conn
}
// Read will read messages from the websocket connection
func (c *Conn) Read(p []byte) (int, error) {
_, message, err := c.Conn.ReadMessage()
if err != nil {
return 0, err
}
return copy(p, message), nil
}
// Write will write messages to the websocket connection
func (c *Conn) Write(p []byte) (int, error) {
if err := c.Conn.WriteMessage(websocket.BinaryMessage, p); err != nil {
return 0, err
}
return len(p), nil
}
// IsWebSocketUpgrade checks to see if the request is a WebSocket connection.
func IsWebSocketUpgrade(req *http.Request) bool {
return websocket.IsWebSocketUpgrade(req)
}
// Dialler is something that can proxy websocket requests.
type Dialler interface {
Dial(url *url.URL, headers http.Header) (*websocket.Conn, *http.Response, error)
}
type defaultDialler struct {
tlsConfig *tls.Config
}
func (dd *defaultDialler) Dial(url *url.URL, header http.Header) (*websocket.Conn, *http.Response, error) {
d := &websocket.Dialer{
TLSClientConfig: dd.tlsConfig,
Proxy: http.ProxyFromEnvironment,
}
return d.Dial(url.String(), header)
}
// ClientConnect creates a WebSocket client connection for provided request. Caller is responsible for closing
// the connection. The response body may not contain the entire response and does
// not need to be closed by the application.
func ClientConnect(req *http.Request, dialler Dialler) (*websocket.Conn, *http.Response, error) {
func ClientConnect(req *http.Request, dialler *websocket.Dialer) (*websocket.Conn, *http.Response, error) {
req.URL.Scheme = ChangeRequestScheme(req.URL)
wsHeaders := websocketHeaders(req)
if dialler == nil {
dialler = new(defaultDialler)
dialler = &websocket.Dialer{
Proxy: http.ProxyFromEnvironment,
}
}
conn, response, err := dialler.Dial(req.URL, wsHeaders)
conn, response, err := dialler.Dial(req.URL.String(), wsHeaders)
if err != nil {
return nil, response, err
}
response.Header.Set("Sec-WebSocket-Accept", generateAcceptKey(req))
return conn, response, err
return conn, response, nil
}
// Stream copies copy data to & from provided io.ReadWriters.
@@ -121,8 +67,8 @@ func Stream(conn, backendConn io.ReadWriter) {
// DefaultStreamHandler is provided to the the standard websocket to origin stream
// This exist to allow SOCKS to deframe data before it gets to the origin
func DefaultStreamHandler(wsConn *Conn, remoteConn net.Conn, _ http.Header) {
Stream(wsConn, remoteConn)
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn) {
Stream(originConn, remoteConn)
}
// StartProxyServer will start a websocket server that will decode
@@ -132,7 +78,7 @@ func StartProxyServer(
listener net.Listener,
staticHost string,
shutdownC <-chan struct{},
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header),
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn),
) error {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
@@ -159,7 +105,7 @@ type handler struct {
log *zerolog.Logger
staticHost string
upgrader websocket.Upgrader
streamHandler func(wsConn *Conn, remoteConn net.Conn, requestHeaders http.Header)
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@@ -192,14 +138,20 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
_ = conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { _ = conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
done := make(chan struct{})
go pinger(h.log, conn, done)
defer func() {
done <- struct{}{}
_ = conn.Close()
}()
gorillaConn := &GorillaConn{conn, h.log}
go gorillaConn.pinger(r.Context())
defer conn.Close()
h.streamHandler(&Conn{conn}, stream, r.Header)
h.streamHandler(gorillaConn, stream)
}
// NewResponseHeader returns headers needed to return to origin for completing handshake
func NewResponseHeader(req *http.Request) http.Header {
header := http.Header{}
header.Add("Connection", "Upgrade")
header.Add("Sec-Websocket-Accept", generateAcceptKey(req))
header.Add("Upgrade", "websocket")
return header
}
// the gorilla websocket library sets its own Upgrade, Connection, Sec-WebSocket-Key,
@@ -246,19 +198,3 @@ func ChangeRequestScheme(reqURL *url.URL) string {
return reqURL.Scheme
}
}
// pinger simulates the websocket connection to keep it alive
func pinger(logger *zerolog.Logger, ws *websocket.Conn, done chan struct{}) {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
logger.Debug().Msgf("failed to send ping message: %s", err)
}
case <-done:
return
}
}
}

View File

@@ -11,7 +11,7 @@ import (
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/tlsconfig"
gws "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"golang.org/x/net/websocket"
)
@@ -78,7 +78,7 @@ func TestServe(t *testing.T) {
tlsConfig := websocketClientTLSConfig(t)
assert.NotNil(t, tlsConfig)
d := defaultDialler{tlsConfig: tlsConfig}
d := gws.Dialer{TLSClientConfig: tlsConfig}
conn, resp, err := ClientConnect(req, &d)
assert.NoError(t, err)
assert.Equal(t, testSecWebsocketAccept, resp.Header.Get("Sec-WebSocket-Accept"))