TUN-4701: Split Proxy into ProxyHTTP and ProxyTCP

http.Request now is only used by ProxyHTTP and not required if the
proxying is TCP. The dest conversion is handled by the transport layer.
This commit is contained in:
Sudarsan Reddy
2021-07-16 16:14:37 +01:00
parent 81dff44bb9
commit 8f3526289a
9 changed files with 263 additions and 135 deletions

View File

@@ -7,7 +7,6 @@ import (
"io"
"net/http"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@@ -20,13 +19,15 @@ import (
)
const (
// TagHeaderNamePrefix indicates a Cloudflared Warp Tag prefix that gets appended for warp traffic stream headers.
TagHeaderNamePrefix = "Cf-Warp-Tag-"
LogFieldCFRay = "cfRay"
LogFieldRule = "ingressRule"
LogFieldOriginService = "originService"
)
type proxy struct {
// Proxy represents a means to Proxy between cloudflared and the origin services.
type Proxy struct {
ingressRules ingress.Ingress
warpRouting *ingress.WarpRoutingService
tags []tunnelpogs.Tag
@@ -34,15 +35,14 @@ type proxy struct {
bufferPool *bufferPool
}
var switchingProtocolText = fmt.Sprintf("%d %s", http.StatusSwitchingProtocols, http.StatusText(http.StatusSwitchingProtocols))
// NewOriginProxy returns a new instance of the Proxy struct.
func NewOriginProxy(
ingressRules ingress.Ingress,
warpRouting *ingress.WarpRoutingService,
tags []tunnelpogs.Tag,
log *zerolog.Logger) connection.OriginProxy {
return &proxy{
log *zerolog.Logger,
) *Proxy {
return &Proxy{
ingressRules: ingressRules,
warpRouting: warpRouting,
tags: tags,
@@ -51,41 +51,18 @@ func NewOriginProxy(
}
}
// Caller is responsible for writing any error to ResponseWriter
func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConnectionType connection.Type) error {
// ProxyHTTP further depends on ingress rules to establish a connection with the origin service. This may be
// a simple roundtrip or a tcp/websocket dial depending on ingres rule setup.
func (p *Proxy) ProxyHTTP(
w connection.ResponseWriter,
req *http.Request,
isWebsocket bool,
) error {
incrementRequests()
defer decrementConcurrentRequests()
cfRay := findCfRayHeader(req)
lbProbe := isLBProbeRequest(req)
serveCtx, cancel := context.WithCancel(req.Context())
defer cancel()
p.appendTagHeaders(req)
if sourceConnectionType == connection.TypeTCP {
if p.warpRouting == nil {
err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`)
p.log.Error().Msg(err.Error())
return err
}
logFields := logFields{
cfRay: cfRay,
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
host, err := getRequestHost(req)
if err != nil {
err = fmt.Errorf(`cloudflared recieved a warp-routing request with an empty host value: %v`, err)
return err
}
if err := p.proxyStreamRequest(serveCtx, w, host, req, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, "", ingress.ServiceWarpRouting)
return err
}
return nil
}
cfRay := connection.FindCfRayHeader(req)
lbProbe := connection.IsLBProbeRequest(req)
rule, ruleNum := p.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
logFields := logFields{
@@ -97,8 +74,14 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
switch originProxy := rule.Service.(type) {
case ingress.HTTPOriginProxy:
if err := p.proxyHTTPRequest(w, req, originProxy, sourceConnectionType == connection.TypeWebsocket,
rule.Config.DisableChunkedEncoding, logFields); err != nil {
if err := p.proxyHTTPRequest(
w,
req,
originProxy,
isWebsocket,
rule.Config.DisableChunkedEncoding,
logFields,
); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
@@ -110,7 +93,9 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
if err != nil {
return err
}
if err := p.proxyStreamRequest(serveCtx, w, dest, req, originProxy, logFields); err != nil {
rws := connection.NewHTTPResponseReadWriterAcker(w, req)
if err := p.proxyStream(req.Context(), rws, dest, originProxy, logFields); err != nil {
rule, srv := ruleField(p.ingressRules, ruleNum)
p.logRequestError(err, cfRay, rule, srv)
return err
@@ -121,24 +106,36 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
}
}
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:
return carrier.ResolveBastionDest(req)
default:
return rule.Service.String(), nil
}
}
// ProxyTCP proxies to a TCP connection between the origin service and cloudflared.
func (p *Proxy) ProxyTCP(
ctx context.Context,
rwa connection.ReadWriteAcker,
req *connection.TCPRequest,
) error {
incrementRequests()
defer decrementConcurrentRequests()
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
if p.warpRouting == nil {
err := errors.New(`cloudflared received a request from WARP client, but your configuration has disabled ingress from WARP clients. To enable this, set "warp-routing:\n\t enabled: true" in your config.yaml`)
p.log.Error().Msg(err.Error())
return err
}
if r.URL != nil {
return r.URL.Host, nil
serveCtx, cancel := context.WithCancel(ctx)
defer cancel()
logFields := logFields{
cfRay: req.CFRay,
lbProbe: req.LBProbe,
rule: ingress.ServiceWarpRouting,
}
return "", errors.New("host not set in incoming request")
if err := p.proxyStream(serveCtx, rwa, req.Dest, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, req.CFRay, "", ingress.ServiceWarpRouting)
return err
}
return nil
}
func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
@@ -149,13 +146,15 @@ func ruleField(ing ingress.Ingress, ruleNum int) (ruleID string, srv string) {
return fmt.Sprintf("%d", ruleNum), srv
}
func (p *proxy) proxyHTTPRequest(
// ProxyHTTPRequest proxies requests of underlying type http and websocket to the origin service.
func (p *Proxy) proxyHTTPRequest(
w connection.ResponseWriter,
req *http.Request,
httpService ingress.HTTPOriginProxy,
isWebsocket bool,
disableChunkedEncoding bool,
fields logFields) error {
fields logFields,
) error {
roundTripReq := req
if isWebsocket {
roundTripReq = req.Clone(req.Context())
@@ -214,17 +213,17 @@ func (p *proxy) proxyHTTPRequest(
defer p.bufferPool.Put(buf)
_, _ = io.CopyBuffer(w, resp.Body, buf)
}
p.logOriginResponse(resp, fields)
return nil
}
// proxyStreamRequest first establish a connection with origin, then it writes the status code and headers, and finally it streams data between
// eyeball and origin.
func (p *proxy) proxyStreamRequest(
serveCtx context.Context,
w connection.ResponseWriter,
// proxyStream proxies type TCP and other underlying types if the connection is defined as a stream oriented
// ingress rule.
func (p *Proxy) proxyStream(
ctx context.Context,
rwa connection.ReadWriteAcker,
dest string,
req *http.Request,
connectionProxy ingress.StreamBasedOriginProxy,
fields logFields,
) error {
@@ -233,21 +232,11 @@ func (p *proxy) proxyStreamRequest(
return err
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
if secWebsocketKey := req.Header.Get("Sec-WebSocket-Key"); secWebsocketKey != "" {
resp.Header = websocket.NewResponseHeader(req)
}
if err = w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
if err := rwa.AckConnection(); err != nil {
return err
}
streamCtx, cancel := context.WithCancel(serveCtx)
streamCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
@@ -256,12 +245,7 @@ func (p *proxy) proxyStreamRequest(
originConn.Close()
}()
eyeballStream := &bidirectionalStream{
writer: w,
reader: req.Body,
}
originConn.Stream(serveCtx, eyeballStream, p.log)
p.logOriginResponse(resp, fields)
originConn.Stream(ctx, rwa, p.log)
return nil
}
@@ -278,7 +262,7 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
return wr.writer.Write(p)
}
func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
reader := bufio.NewReader(respBody)
for {
line, err := reader.ReadBytes('\n')
@@ -289,7 +273,7 @@ func (p *proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCl
}
}
func (p *proxy) appendTagHeaders(r *http.Request) {
func (p *Proxy) appendTagHeaders(r *http.Request) {
for _, tag := range p.tags {
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
}
@@ -301,7 +285,7 @@ type logFields struct {
rule interface{}
}
func (p *proxy) logRequest(r *http.Request, fields logFields) {
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
} else if fields.lbProbe {
@@ -324,7 +308,7 @@ func (p *proxy) logRequest(r *http.Request, fields logFields) {
}
}
func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
func (p *Proxy) logOriginResponse(resp *http.Response, fields logFields) {
responseByCode.WithLabelValues(strconv.Itoa(resp.StatusCode)).Inc()
if fields.cfRay != "" {
p.log.Debug().Msgf("CF-RAY: %s Status: %s served by ingress %d", fields.cfRay, resp.Status, fields.rule)
@@ -342,7 +326,7 @@ func (p *proxy) logOriginResponse(resp *http.Response, fields logFields) {
}
}
func (p *proxy) logRequestError(err error, cfRay string, rule, service string) {
func (p *Proxy) logRequestError(err error, cfRay string, rule, service string) {
requestErrors.Inc()
log := p.log.Error().Err(err)
if cfRay != "" {
@@ -357,10 +341,11 @@ func (p *proxy) logRequestError(err error, cfRay string, rule, service string) {
log.Msg("")
}
func findCfRayHeader(req *http.Request) string {
return req.Header.Get("Cf-Ray")
}
func isLBProbeRequest(req *http.Request) bool {
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
func getDestFromRule(rule *ingress.Rule, req *http.Request) (string, error) {
switch rule.Service.String() {
case ingress.ServiceBastion:
return carrier.ResolveBastionDest(req)
default:
return rule.Service.String(), nil
}
}

View File

@@ -46,6 +46,10 @@ func newMockHTTPRespWriter() *mockHTTPRespWriter {
}
}
func (w *mockHTTPRespWriter) WriteResponse() error {
return nil
}
func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
w.WriteHeader(status)
for header, val := range header {
@@ -146,7 +150,7 @@ func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
require.NoError(t, err)
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
err = proxy.ProxyHTTP(responseWriter, req, false)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, responseWriter.Code)
@@ -170,7 +174,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
err = proxy.ProxyHTTP(responseWriter, req, true)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
@@ -231,7 +235,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
err = proxy.ProxyHTTP(responseWriter, req, false)
require.NoError(t, err)
require.Equal(t, http.StatusOK, responseWriter.Code)
@@ -330,7 +334,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
req, err := http.NewRequest(http.MethodGet, test.url, nil)
require.NoError(t, err)
err = proxy.Proxy(responseWriter, req, connection.TypeHTTP)
err = proxy.ProxyHTTP(responseWriter, req, false)
require.NoError(t, err)
assert.Equal(t, test.expectedStatus, responseWriter.Code)
@@ -358,7 +362,7 @@ func (errorOriginTransport) RoundTrip(*http.Request) (*http.Response, error) {
}
func TestProxyError(t *testing.T) {
ingress := ingress.Ingress{
ing := ingress.Ingress{
Rules: []ingress.Rule{
{
Hostname: "*",
@@ -372,13 +376,13 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ingress, unusedWarpRoutingService, testTags, &log)
proxy := NewOriginProxy(ing, unusedWarpRoutingService, testTags, &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
assert.Error(t, proxy.Proxy(responseWriter, req, connection.TypeHTTP))
assert.Error(t, proxy.ProxyHTTP(responseWriter, req, false))
}
type replayer struct {
@@ -617,6 +621,7 @@ func TestConnections(t *testing.T) {
ingressRule.StartOrigins(&wg, logger, ctx.Done(), errC)
proxy := NewOriginProxy(ingressRule, test.args.warpRoutingService, testTags, logger)
dest := ln.Addr().String()
req, err := http.NewRequest(
http.MethodGet,
test.args.ingressServiceScheme+ln.Addr().String(),
@@ -634,8 +639,12 @@ func TestConnections(t *testing.T) {
replayer.Write(resp)
}()
}
err = proxy.Proxy(respWriter, req, test.args.connectionType)
if test.args.connectionType == connection.TypeTCP {
rws := connection.NewHTTPResponseReadWriterAcker(respWriter, req)
err = proxy.ProxyTCP(ctx, rws, &connection.TCPRequest{Dest: dest})
} else {
err = proxy.ProxyHTTP(respWriter, req, test.args.connectionType == connection.TypeWebsocket)
}
cancel()
assert.Equal(t, test.want.err, err != nil)
@@ -829,6 +838,10 @@ func newTCPRespWriter(w io.Writer) *mockTCPRespWriter {
}
}
func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
return len(p), nil
}
func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
return m.w.Write(p)
}

View File

@@ -26,7 +26,6 @@ import (
const (
dialTimeout = 15 * time.Second
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
FeatureSerializedHeaders = "serialized_headers"
FeatureQuickReconnects = "quick_reconnects"
)
@@ -417,6 +416,7 @@ func ServeHTTP2(
config.Observer,
connIndex,
connectedFuse,
config.Log,
gracefulShutdownC,
)