TUN-4655: ingress.StreamBasedProxy.EstablishConnection takes dest input

This change extracts the need for EstablishConnection to know about a
request's entire context. It also removes the concern of populating the
http.Response from EstablishConnection's responsibilities.
This commit is contained in:
Sudarsan Reddy
2021-07-01 19:30:26 +01:00
parent f1b57526b3
commit d678584d89
7 changed files with 69 additions and 94 deletions

View File

@@ -103,7 +103,7 @@ func NewWarpRoutingService() *WarpRoutingService {
}
// Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
if c.IsSet("hello-world") {
return new(helloWorld), nil
}
@@ -167,7 +167,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
rules := make([]Rule, len(ingress))
for i, r := range ingress {
cfg := setConfig(defaults, r.OriginRequest)
var service originService
var service OriginService
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
// No validation necessary for unix socket filepath services

View File

@@ -6,9 +6,6 @@ import (
"net/http"
"github.com/pkg/errors"
"github.com/cloudflare/cloudflared/carrier"
"github.com/cloudflare/cloudflared/websocket"
)
var (
@@ -24,7 +21,7 @@ type HTTPOriginProxy interface {
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
type StreamBasedOriginProxy interface {
EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error)
EstablishConnection(dest string) (OriginConnection, error)
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -54,73 +51,36 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
func (o *rawTCPService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
dest, err := getRequestHost(r)
if err != nil {
return nil, nil, err
}
func (o *rawTCPService) EstablishConnection(dest string) (OriginConnection, error) {
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpConnection{
conn: conn,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
// getRequestHost returns the host of the http.Request.
func getRequestHost(r *http.Request) (string, error) {
if r.Host != "" {
return r.Host, nil
}
if r.URL != nil {
return r.URL.Host, nil
}
return "", errors.New("host not found")
}
func (o *tcpOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
func (o *tcpOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
var err error
dest := o.dest
if o.isBastion {
dest, err = carrier.ResolveBastionDest(r)
if err != nil {
return nil, nil, err
}
if !o.isBastion {
dest = o.dest
}
conn, err := net.Dial("tcp", dest)
if err != nil {
return nil, nil, err
return nil, err
}
originConn := &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
return originConn, nil
}
func (o *socksProxyOverWSService) EstablishConnection(r *http.Request) (OriginConnection, *http.Response, error) {
originConn := o.conn
resp := &http.Response{
Status: switchingProtocolText,
StatusCode: http.StatusSwitchingProtocols,
Header: websocket.NewResponseHeader(r),
ContentLength: -1,
}
return originConn, resp, nil
func (o *socksProxyOverWSService) EstablishConnection(dest string) (OriginConnection, error) {
return o.conn, nil
}

View File

@@ -17,20 +17,6 @@ import (
"github.com/cloudflare/cloudflared/websocket"
)
// TestEstablishConnectionResponse ensures each implementation of StreamBasedOriginProxy returns
// the expected response
func assertEstablishConnectionResponse(t *testing.T,
originProxy StreamBasedOriginProxy,
req *http.Request,
expectHeader http.Header,
) {
_, resp, err := originProxy.EstablishConnection(req)
assert.NoError(t, err)
assert.Equal(t, switchingProtocolText, resp.Status)
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
assert.Equal(t, expectHeader, resp.Header)
}
func TestRawTCPServiceEstablishConnection(t *testing.T) {
originListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -43,8 +29,6 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s", originListener.Addr()), nil)
require.NoError(t, err)
assertEstablishConnectionResponse(t, rawTCPService, req, nil)
originListener.Close()
<-listenerClosed
@@ -52,9 +36,8 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
require.NoError(t, err)
// Origin not listening for new connection, should return an error
_, resp, err := rawTCPService.EstablishConnection(req)
_, err = rawTCPService.EstablishConnection(req.URL.String())
require.Error(t, err)
require.Nil(t, resp)
}
func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
@@ -76,12 +59,6 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
bastionReq := baseReq.Clone(context.Background())
carrier.SetBastionDest(bastionReq.Header, originListener.Addr().String())
expectHeader := http.Header{
"Connection": {"Upgrade"},
"Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
"Upgrade": {"websocket"},
}
tests := []struct {
testCase string
service *tcpOverWSService
@@ -109,11 +86,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, test := range tests {
t.Run(test.testCase, func(t *testing.T) {
if test.expectErr {
_, resp, err := test.service.EstablishConnection(test.req)
bastionHost, _ := carrier.ResolveBastionDest(test.req)
_, err := test.service.EstablishConnection(bastionHost)
assert.Error(t, err)
assert.Nil(t, resp)
} else {
assertEstablishConnectionResponse(t, test.service, test.req, expectHeader)
}
})
}
@@ -123,9 +98,9 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
// Origin not listening for new connection, should return an error
_, resp, err := service.EstablishConnection(bastionReq)
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
_, err := service.EstablishConnection(bastionHost)
assert.Error(t, err)
assert.Nil(t, resp)
}
}

View File

@@ -20,8 +20,8 @@ import (
"github.com/cloudflare/cloudflared/tlsconfig"
)
// originService is something a tunnel can proxy traffic to.
type originService interface {
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
@@ -238,7 +238,7 @@ func (nrc *NopReadCloser) Close() error {
return nil
}
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")

View File

@@ -17,7 +17,7 @@ type Rule struct {
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service originService
Service OriginService
// Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig

View File

@@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service originService
Service OriginService
}
type args struct {
requestURL *url.URL