mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:09:58 +00:00
TUN-3863: Consolidate header handling logic in the connection package; move headers definitions from h2mux to packages that manage them; cleanup header conversions
All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic. Exported headers used by proxying between edge and cloudflared so then can be shared by tunnel service on the edge. Moved access-related headers to corresponding packages that have the code that sets/uses these headers. Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibilty, but it's no longer used by cloudflared. Move bastion-related logic into carrier package, untangled dependencies between carrier, origin, and websocket packages.
This commit is contained in:
@@ -25,34 +25,10 @@ type OriginConnection interface {
|
||||
|
||||
type streamHandlerFunc func(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger)
|
||||
|
||||
// Stream copies copy data to & from provided io.ReadWriters.
|
||||
func Stream(conn, backendConn io.ReadWriter, log *zerolog.Logger) {
|
||||
proxyDone := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(conn, backendConn)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("conn to backendConn copy: %v", err)
|
||||
}
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(backendConn, conn)
|
||||
if err != nil {
|
||||
log.Debug().Msgf("backendConn to conn copy: %v", err)
|
||||
}
|
||||
proxyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
// If one side is done, we are done.
|
||||
<-proxyDone
|
||||
}
|
||||
|
||||
// DefaultStreamHandler is an implementation of streamHandlerFunc that
|
||||
// performs a two way io.Copy between originConn and remoteConn.
|
||||
func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *zerolog.Logger) {
|
||||
Stream(originConn, remoteConn, log)
|
||||
websocket.Stream(originConn, remoteConn, log)
|
||||
}
|
||||
|
||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||
@@ -61,7 +37,7 @@ type tcpConnection struct {
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
Stream(tunnelConn, tc.conn, log)
|
||||
websocket.Stream(tunnelConn, tc.conn, log)
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Close() {
|
||||
@@ -89,7 +65,7 @@ type wsConnection struct {
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn(), log)
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Close() {
|
||||
|
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -157,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer wsForwarderInConn.Close()
|
||||
|
||||
Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||
websocket.Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||
return nil
|
||||
})
|
||||
|
||||
|
@@ -4,12 +4,10 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
@@ -106,7 +104,7 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio
|
||||
var err error
|
||||
dest := o.dest
|
||||
if o.isBastion {
|
||||
dest, err = o.bastionDest(r)
|
||||
dest, err = carrier.ResolveBastionDest(r)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -130,23 +128,6 @@ func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnectio
|
||||
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) bastionDest(r *http.Request) (string, error) {
|
||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||
if jumpDestination == "" {
|
||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||
}
|
||||
// Strip scheme and path set by client. Without a scheme
|
||||
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
|
||||
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
|
||||
return removePath(jumpURL.Host), nil
|
||||
}
|
||||
return removePath(jumpDestination), nil
|
||||
}
|
||||
|
||||
func removePath(dest string) string {
|
||||
return strings.SplitN(dest, "/", 2)[0]
|
||||
}
|
||||
|
||||
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
|
||||
originConn := o.conn
|
||||
resp := &http.Response{
|
||||
|
@@ -14,7 +14,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
baseReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
|
||||
bastionReq := baseReq.Clone(context.Background())
|
||||
bastionReq.Header.Set(h2mux.CFJumpDestinationHeader, originListener.Addr().String())
|
||||
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
|
||||
|
||||
expectHeader := http.Header{
|
||||
"Connection": {"Upgrade"},
|
||||
@@ -135,19 +135,23 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testCase string
|
||||
service *tcpOverWSService
|
||||
req *http.Request
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
service: newTCPOverWSService(originURL),
|
||||
req: baseReq,
|
||||
testCase: "specific TCP service",
|
||||
service: newTCPOverWSService(originURL),
|
||||
req: baseReq,
|
||||
},
|
||||
{
|
||||
service: newBastionService(),
|
||||
req: bastionReq,
|
||||
testCase: "bastion service",
|
||||
service: newBastionService(),
|
||||
req: bastionReq,
|
||||
},
|
||||
{
|
||||
testCase: "invalid bastion request",
|
||||
service: newBastionService(),
|
||||
req: baseReq,
|
||||
expectErr: true,
|
||||
@@ -155,13 +159,15 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
if test.expectErr {
|
||||
_, resp, err := test.service.EstablishConnection(test.req)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||
}
|
||||
t.Run(test.testCase, func(t *testing.T) {
|
||||
if test.expectErr {
|
||||
_, resp, err := test.service.EstablishConnection(test.req)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
originListener.Close()
|
||||
@@ -175,104 +181,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBastionDestination(t *testing.T) {
|
||||
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
||||
tests := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
expectedDest string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "hostname destination",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost"},
|
||||
},
|
||||
expectedDest: "localhost",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost:9000"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with scheme and port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "full hostname url",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with port and path",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost:9000/metrics"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1"},
|
||||
},
|
||||
expectedDest: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "ip destination with port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination with port and path",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination with schem and port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "full ip url",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "no destination",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
s := newBastionService()
|
||||
for _, test := range tests {
|
||||
r := &http.Request{
|
||||
Header: test.header,
|
||||
}
|
||||
dest, err := s.bastionDest(r)
|
||||
if test.wantErr {
|
||||
assert.Error(t, err, "Test %s expects error", test.name)
|
||||
} else {
|
||||
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
|
||||
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
cfg := OriginRequestConfig{
|
||||
HTTPHostHeader: t.Name(),
|
||||
|
Reference in New Issue
Block a user