TUN-4626: Proxy non-stream based origin websockets with http Roundtrip.

Reuses HTTPProxy's Roundtrip method to directly proxy websockets from
eyeball clients (determined by websocket type and ingress not being
connection oriented , i.e. Not ssh or smb for example) to proxy
websocket traffic.
This commit is contained in:
Sudarsan Reddy
2021-07-01 10:29:53 +01:00
parent 3eb9efd9f0
commit f1b57526b3
5 changed files with 76 additions and 218 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/ingress"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/cloudflare/cloudflared/websocket"
)
const (
@@ -85,27 +86,27 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
}
p.logRequest(req, logFields)
if sourceConnectionType == connection.TypeHTTP {
if err := p.proxyHTTPRequest(w, req, rule, logFields); err != nil {
switch originProxy := rule.Service.(type) {
case ingress.HTTPOriginProxy:
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
rule.Config.DisableChunkedEncoding, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
}
connectionProxy, ok := rule.Service.(ingress.StreamBasedOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a connection-oriented service", rule.Service)
return fmt.Errorf("Not a connection-oriented service")
}
case ingress.StreamBasedOriginProxy:
if err := p.proxyStreamRequest(serveCtx, w, req, originProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
default:
return fmt.Errorf("Unrecognized service: %s, %t", rule.Service, originProxy)
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
}
return nil
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
@@ -116,26 +117,35 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
return fmt.Sprintf("%d", ruleNum), srv
}
func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request, rule *ingress.Rule, fields logFields) error {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if rule.Config.DisableChunkedEncoding {
req.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
req.ContentLength = int64(cLength)
func (p *proxy) proxyHTTPRequest(
w connection.ResponseWriter,
req *http.Request,
httpService ingress.HTTPOriginProxy,
isWebsocket bool,
disableChunkedEncoding bool,
fields logFields) error {
roundTripReq := req
if isWebsocket {
roundTripReq = req.Clone(req.Context())
roundTripReq.Header.Set("Connection", "Upgrade")
roundTripReq.Header.Set("Upgrade", "websocket")
roundTripReq.Header.Set("Sec-Websocket-Version", "13")
roundTripReq.ContentLength = 0
roundTripReq.Body = nil
} else {
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
if disableChunkedEncoding {
roundTripReq.TransferEncoding = []string{"gzip", "deflate"}
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
if err == nil {
roundTripReq.ContentLength = int64(cLength)
}
}
// Request origin to keep connection alive to improve performance
roundTripReq.Header.Set("Connection", "keep-alive")
}
// Request origin to keep connection alive to improve performance
req.Header.Set("Connection", "keep-alive")
httpService, ok := rule.Service.(ingress.HTTPOriginProxy)
if !ok {
p.log.Error().Msgf("%s is not a http service", rule.Service)
return fmt.Errorf("Not a http service")
}
resp, err := httpService.RoundTrip(req)
resp, err := httpService.RoundTrip(roundTripReq)
if err != nil {
return errors.Wrap(err, "Unable to reach the origin service. The service may be down or it may not be responding to traffic from cloudflared")
}
@@ -145,6 +155,23 @@ func (p *proxy) proxyHTTPRequest(w connection.ResponseWriter, req *http.Request,
if err != nil {
return errors.Wrap(err, "Error writing response header")
}
if resp.StatusCode == http.StatusSwitchingProtocols {
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return errors.New("internal error: unsupported connection type")
}
defer rwc.Close()
eyeballStream := &bidirectionalStream{
writer: w,
reader: req.Body,
}
websocket.Stream(eyeballStream, rwc, p.log)
return nil
}
if connection.IsServerSentEvent(resp.Header) {
p.log.Debug().Msg("Detected Server-Side Events from Origin")
p.writeEventStream(w, resp.Body)

View File

@@ -571,8 +571,14 @@ func TestConnections(t *testing.T) {
},
},
want: want{
message: []byte{},
err: true,
message: []byte("Forbidden\n"),
err: false,
headers: map[string][]string{
"Content-Length": {"10"},
"Content-Type": {"text/plain; charset=utf-8"},
"Sec-Websocket-Version": {"13"},
"X-Content-Type-Options": {"nosniff"},
},
},
},
{
@@ -806,6 +812,8 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
// respHeaders is a test function to read respHeaders
func (w *wsRespWriter) headers() http.Header {
// Removing indeterminstic header because it cannot be asserted.
w.responseHeaders.Del("Date")
return w.responseHeaders
}