mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:19:57 +00:00
TUN-3617: Separate service from client, and implement different client for http vs. tcp origins
- extracted ResponseWriter from proxyConnection - added bastion tests over websocket - removed HTTPResp() - added some docstrings - Renamed some ingress clients as proxies - renamed instances of client to proxy in connection and origin - Stream no longer takes a context and logger.Service
This commit is contained in:
@@ -85,16 +85,24 @@ func NewSingleOrigin(c *cli.Context, allowURLFromArgs bool) (Ingress, error) {
|
||||
}
|
||||
|
||||
// Get a single origin service from the CLI/config.
|
||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginService, error) {
|
||||
func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (originService, error) {
|
||||
if c.IsSet("hello-world") {
|
||||
return new(helloWorld), nil
|
||||
}
|
||||
if c.IsSet("url") || c.IsSet(config.BastionFlag) {
|
||||
if c.IsSet(config.BastionFlag) {
|
||||
return newBridgeService(), nil
|
||||
}
|
||||
if c.IsSet("url") {
|
||||
originURL, err := config.ValidateUrl(c, allowURLFromArgs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error validating origin URL")
|
||||
}
|
||||
return &localService{URL: originURL, RootURL: originURL}, nil
|
||||
if isHTTPService(originURL) {
|
||||
return &httpService{
|
||||
url: originURL,
|
||||
}, nil
|
||||
}
|
||||
return newSingleTCPService(originURL), nil
|
||||
}
|
||||
if c.IsSet("unix-socket") {
|
||||
path, err := config.ValidateUnixSocket(c)
|
||||
@@ -104,7 +112,7 @@ func parseSingleOriginService(c *cli.Context, allowURLFromArgs bool) (OriginServ
|
||||
return &unixSocketPath{path: path}, nil
|
||||
}
|
||||
u, err := url.Parse("http://localhost:8080")
|
||||
return &localService{URL: u, RootURL: u}, err
|
||||
return &httpService{url: u}, err
|
||||
}
|
||||
|
||||
// IsEmpty checks if there are any ingress rules.
|
||||
@@ -136,7 +144,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||
rules := make([]Rule, len(ingress))
|
||||
for i, r := range ingress {
|
||||
cfg := setConfig(defaults, r.OriginRequest)
|
||||
var service OriginService
|
||||
var service originService
|
||||
|
||||
if prefix := "unix:"; strings.HasPrefix(r.Service, prefix) {
|
||||
// No validation necessary for unix socket filepath services
|
||||
@@ -156,7 +164,7 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||
// overwrite the localService.URL field when `start` is called. So,
|
||||
// leave the URL field empty for now.
|
||||
cfg.BastionMode = true
|
||||
service = new(localService)
|
||||
service = newBridgeService()
|
||||
} else {
|
||||
// Validate URL services
|
||||
u, err := url.Parse(r.Service)
|
||||
@@ -171,8 +179,11 @@ func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestCon
|
||||
if u.Path != "" {
|
||||
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
|
||||
}
|
||||
serviceURL := localService{URL: u}
|
||||
service = &serviceURL
|
||||
if isHTTPService(u) {
|
||||
service = &httpService{url: u}
|
||||
} else {
|
||||
service = newSingleTCPService(u)
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateHostname(r, i, len(ingress)); err != nil {
|
||||
@@ -241,3 +252,7 @@ func ParseIngress(conf *config.Configuration) (Ingress, error) {
|
||||
}
|
||||
return validate(conf.Ingress, originRequestFromYAML(conf.OriginRequest))
|
||||
}
|
||||
|
||||
func isHTTPService(url *url.URL) bool {
|
||||
return url.Scheme == "http" || url.Scheme == "https" || url.Scheme == "ws" || url.Scheme == "wss"
|
||||
}
|
||||
|
@@ -61,12 +61,12 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "tunnel1.example.com",
|
||||
Service: &localService{URL: localhost8000},
|
||||
Service: &httpService{url: localhost8000},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
{
|
||||
Hostname: "*",
|
||||
Service: &localService{URL: localhost8001},
|
||||
Service: &httpService{url: localhost8001},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -82,7 +82,22 @@ extraKey: extraValue
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "*",
|
||||
Service: &localService{URL: localhost8000},
|
||||
Service: &httpService{url: localhost8000},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ws service",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- hostname: "*"
|
||||
service: wss://localhost:8000
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "*",
|
||||
Service: &httpService{url: MustParseURL(t, "wss://localhost:8000")},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -95,7 +110,7 @@ ingress:
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: &localService{URL: localhost8000},
|
||||
Service: &httpService{url: localhost8000},
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
@@ -209,6 +224,85 @@ ingress:
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "TCP services",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- hostname: tcp.foo.com
|
||||
service: tcp://127.0.0.1
|
||||
- hostname: tcp2.foo.com
|
||||
service: tcp://localhost:8000
|
||||
- service: http_status:404
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "tcp.foo.com",
|
||||
Service: newSingleTCPService(MustParseURL(t, "tcp://127.0.0.1:7864")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
{
|
||||
Hostname: "tcp2.foo.com",
|
||||
Service: newSingleTCPService(MustParseURL(t, "tcp://localhost:8000")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
{
|
||||
Service: &fourOhFour,
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SSH services",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- service: ssh://127.0.0.1
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "ssh://127.0.0.1:22")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RDP services",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- service: rdp://127.0.0.1
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "rdp://127.0.0.1:3389")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SMB services",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- service: smb://127.0.0.1
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "smb://127.0.0.1:445")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Other TCP services",
|
||||
args: args{rawYAML: `
|
||||
ingress:
|
||||
- service: ftp://127.0.0.1
|
||||
`},
|
||||
want: []Rule{
|
||||
{
|
||||
Service: newSingleTCPService(MustParseURL(t, "ftp://127.0.0.1")),
|
||||
Config: defaultConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "URL isn't necessary if using bastion",
|
||||
args: args{rawYAML: `
|
||||
@@ -221,7 +315,7 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "bastion.foo.com",
|
||||
Service: &localService{},
|
||||
Service: newBridgeService(),
|
||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||
},
|
||||
{
|
||||
@@ -241,7 +335,7 @@ ingress:
|
||||
want: []Rule{
|
||||
{
|
||||
Hostname: "bastion.foo.com",
|
||||
Service: &localService{},
|
||||
Service: newBridgeService(),
|
||||
Config: setConfig(originRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{BastionMode: &tr}),
|
||||
},
|
||||
{
|
||||
@@ -409,6 +503,37 @@ func TestFindMatchingRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHTTPService(t *testing.T) {
|
||||
tests := []struct {
|
||||
url *url.URL
|
||||
isHTTP bool
|
||||
}{
|
||||
{
|
||||
url: MustParseURL(t, "http://localhost"),
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
url: MustParseURL(t, "https://127.0.0.1:8000"),
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
url: MustParseURL(t, "ws://localhost"),
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
url: MustParseURL(t, "wss://localhost:8000"),
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
url: MustParseURL(t, "tcp://localhost:9000"),
|
||||
isHTTP: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.isHTTP, isHTTPService(test.url))
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePath(t *testing.T, path string) *regexp.Regexp {
|
||||
regexp, err := regexp.Compile(path)
|
||||
assert.NoError(t, err)
|
||||
|
62
ingress/origin_connection.go
Normal file
62
ingress/origin_connection.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
gws "github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// OriginConnection is a way to stream to a service running on the user's origin.
|
||||
// Different concrete implementations will stream different protocols as long as they are io.ReadWriters.
|
||||
type OriginConnection interface {
|
||||
// Stream should generally be implemented as a bidirectional io.Copy.
|
||||
Stream(tunnelConn io.ReadWriter)
|
||||
Close()
|
||||
}
|
||||
|
||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||
type tcpConnection struct {
|
||||
conn net.Conn
|
||||
streamHandler func(tunnelConn io.ReadWriter, originConn net.Conn)
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Stream(tunnelConn io.ReadWriter) {
|
||||
tc.streamHandler(tunnelConn, tc.conn)
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Close() {
|
||||
tc.conn.Close()
|
||||
}
|
||||
|
||||
// wsConnection is an OriginConnection that streams to TCP packets by encapsulating them in Websockets.
|
||||
// TODO: TUN-3710 Remove wsConnection and have helloworld service reuse tcpConnection like bridgeService does.
|
||||
type wsConnection struct {
|
||||
wsConn *gws.Conn
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Stream(tunnelConn io.ReadWriter) {
|
||||
websocket.Stream(tunnelConn, wsc.wsConn.UnderlyingConn())
|
||||
}
|
||||
|
||||
func (wsc *wsConnection) Close() {
|
||||
wsc.resp.Body.Close()
|
||||
wsc.wsConn.Close()
|
||||
}
|
||||
|
||||
func newWSConnection(transport *http.Transport, r *http.Request) (OriginConnection, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: transport.TLSClientConfig,
|
||||
}
|
||||
wsConn, resp, err := websocket.ClientConnect(r, d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wsConnection{
|
||||
wsConn,
|
||||
resp,
|
||||
}, nil
|
||||
}
|
100
ingress/origin_proxy.go
Normal file
100
ingress/origin_proxy.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
type HTTPOriginProxy interface {
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
}
|
||||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy at the L4 level.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(r *http.Request) (OriginConnection, error)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// TODO: TUN-3636: establish connection to origins over UDS
|
||||
func (*unixSocketPath) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
return nil, fmt.Errorf("Unix socket service currently doesn't support proxying connections")
|
||||
}
|
||||
|
||||
func (o *httpService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.url.Host
|
||||
req.URL.Scheme = o.url.Scheme
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) EstablishConnection(req *http.Request) (OriginConnection, error) {
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "wss"
|
||||
return newWSConnection(o.transport, req)
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
func (o *bridgeService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
dest, err := o.destination(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return o.client.connect(r, dest)
|
||||
}
|
||||
|
||||
func (o *bridgeService) destination(r *http.Request) (string, error) {
|
||||
jumpDestination := r.Header.Get(h2mux.CFJumpDestinationHeader)
|
||||
if jumpDestination == "" {
|
||||
return "", fmt.Errorf("Did not receive final destination from client. The --destination flag is likely not set on the client side")
|
||||
}
|
||||
// Strip scheme and path set by client. Without a scheme
|
||||
// Parsing a hostname and path without scheme might not return an error due to parsing ambiguities
|
||||
if jumpURL, err := url.Parse(jumpDestination); err == nil && jumpURL.Host != "" {
|
||||
return removePath(jumpURL.Host), nil
|
||||
}
|
||||
return removePath(jumpDestination), nil
|
||||
}
|
||||
|
||||
func removePath(dest string) string {
|
||||
return strings.SplitN(dest, "/", 2)[0]
|
||||
}
|
||||
|
||||
func (o *singleTCPService) EstablishConnection(r *http.Request) (OriginConnection, error) {
|
||||
return o.client.connect(r, o.dest)
|
||||
}
|
||||
|
||||
type tcpClient struct {
|
||||
streamHandler func(originConn io.ReadWriter, remoteConn net.Conn)
|
||||
}
|
||||
|
||||
func (c *tcpClient) connect(r *http.Request, addr string) (OriginConnection, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tcpConnection{
|
||||
conn: conn,
|
||||
streamHandler: c.streamHandler,
|
||||
}, nil
|
||||
}
|
107
ingress/origin_proxy_test.go
Normal file
107
ingress/origin_proxy_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestBridgeServiceDestination(t *testing.T) {
|
||||
canonicalJumpDestHeader := http.CanonicalHeaderKey(h2mux.CFJumpDestinationHeader)
|
||||
tests := []struct {
|
||||
name string
|
||||
header http.Header
|
||||
expectedDest string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "hostname destination",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost"},
|
||||
},
|
||||
expectedDest: "localhost",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost:9000"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with scheme and port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "full hostname url",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://localhost:9000/metrics"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "hostname destination with port and path",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"localhost:9000/metrics"},
|
||||
},
|
||||
expectedDest: "localhost:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1"},
|
||||
},
|
||||
expectedDest: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "ip destination with port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination with port and path",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"127.0.0.1:9000/metrics"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "ip destination with schem and port",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"tcp://127.0.0.1:9000"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "full ip url",
|
||||
header: http.Header{
|
||||
canonicalJumpDestHeader: []string{"ssh://127.0.0.1:9000/metrics"},
|
||||
},
|
||||
expectedDest: "127.0.0.1:9000",
|
||||
},
|
||||
{
|
||||
name: "no destination",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
s := newBridgeService()
|
||||
for _, test := range tests {
|
||||
r := &http.Request{
|
||||
Header: test.header,
|
||||
}
|
||||
dest, err := s.destination(r)
|
||||
if test.wantErr {
|
||||
assert.Error(t, err, "Test %s expects error", test.name)
|
||||
} else {
|
||||
assert.NoError(t, err, "Test %s expects no error, got error %v", test.name, err)
|
||||
assert.Equal(t, test.expectedDest, dest, "Test %s expect dest %s, got %s", test.name, test.expectedDest, dest)
|
||||
}
|
||||
}
|
||||
}
|
@@ -8,7 +8,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,10 +20,8 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// OriginService is something a tunnel can proxy traffic to.
|
||||
type OriginService interface {
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
// originService is something a tunnel can proxy traffic to.
|
||||
type originService interface {
|
||||
String() string
|
||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||
@@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
NetDial: o.transport.Dial,
|
||||
@@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
|
||||
type localService struct {
|
||||
// The URL for the user's origin service
|
||||
RootURL *url.URL
|
||||
// The URL that cloudflared should send requests to.
|
||||
// If this origin requires starting a proxy, this is the proxy's address,
|
||||
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
||||
URL *url.URL
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
reqURL.Host = o.URL.Host
|
||||
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
transport, err := newHTTPTransport(o, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
|
||||
// Start a proxy if one is needed
|
||||
if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) {
|
||||
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *httpService) String() string {
|
||||
return o.url.String()
|
||||
}
|
||||
|
||||
// Start a listener for the proxy
|
||||
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
||||
listener, err := net.Listen("tcp", proxyAddress)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err)
|
||||
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
||||
// bridgeService is like a jump host, the destination is specified by the client
|
||||
type bridgeService struct {
|
||||
client *tcpClient
|
||||
}
|
||||
|
||||
func newBridgeService() *bridgeService {
|
||||
return &bridgeService{
|
||||
client: &tcpClient{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start the proxy itself
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamHandler := websocket.DefaultStreamHandler
|
||||
// This origin's config specifies what type of proxy to start.
|
||||
switch cfg.ProxyType {
|
||||
case socksProxy:
|
||||
log.Info().Msg("SOCKS5 server started")
|
||||
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
|
||||
dialer := socks.NewConnDialer(remoteConn)
|
||||
requestHandler := socks.NewRequestHandler(dialer)
|
||||
socksServer := socks.NewConnectionHandler(requestHandler)
|
||||
func (o *bridgeService) String() string {
|
||||
return "bridge service"
|
||||
}
|
||||
|
||||
_ = socksServer.Serve(wsConn)
|
||||
}
|
||||
case "":
|
||||
log.Debug().Msg("Not starting any websocket proxy")
|
||||
default:
|
||||
log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
|
||||
}
|
||||
|
||||
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
|
||||
}()
|
||||
|
||||
// Modify this origin, so that it no longer points at the origin service directly.
|
||||
// Instead, it points at the proxy to the origin service.
|
||||
newURL, err := url.Parse("http://" + listener.Addr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||
}
|
||||
o.URL = newURL
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *localService) String() string {
|
||||
if o.isBastion() {
|
||||
return "Bastion"
|
||||
}
|
||||
return o.URL.String()
|
||||
type singleTCPService struct {
|
||||
dest string
|
||||
client *tcpClient
|
||||
}
|
||||
|
||||
func (o *localService) isBastion() bool {
|
||||
return o.URL == nil
|
||||
}
|
||||
|
||||
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.URL.Host
|
||||
req.URL.Scheme = o.URL.Scheme
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *localService) staticHost() string {
|
||||
|
||||
if o.URL == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
addPortIfMissing := func(uri *url.URL, port int) string {
|
||||
if uri.Port() != "" {
|
||||
return uri.Host
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||
}
|
||||
|
||||
switch o.URL.Scheme {
|
||||
func newSingleTCPService(url *url.URL) *singleTCPService {
|
||||
switch url.Scheme {
|
||||
case "ssh":
|
||||
return addPortIfMissing(o.URL, 22)
|
||||
addPortIfMissing(url, 22)
|
||||
case "rdp":
|
||||
return addPortIfMissing(o.URL, 3389)
|
||||
addPortIfMissing(url, 3389)
|
||||
case "smb":
|
||||
return addPortIfMissing(o.URL, 445)
|
||||
addPortIfMissing(url, 445)
|
||||
case "tcp":
|
||||
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
|
||||
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
|
||||
}
|
||||
return ""
|
||||
return &singleTCPService{
|
||||
dest: url.Host,
|
||||
client: &tcpClient{},
|
||||
}
|
||||
}
|
||||
|
||||
func addPortIfMissing(uri *url.URL, port int) {
|
||||
if uri.Port() == "" {
|
||||
uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *singleTCPService) String() string {
|
||||
return o.dest
|
||||
}
|
||||
|
||||
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||
@@ -228,26 +178,6 @@ func (o *helloWorld) start(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: o.transport.TLSClientConfig,
|
||||
}
|
||||
reqURL.Host = o.server.Addr().String()
|
||||
reqURL.Scheme = "wss"
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
||||
return staticHost != "" || cfg.BastionMode
|
||||
}
|
||||
|
||||
// statusCode is an OriginService that just responds with a given HTTP status.
|
||||
// Typical use-case is "user wants the catch-all rule to just respond 404".
|
||||
type statusCode struct {
|
||||
@@ -277,10 +207,6 @@ func (o *statusCode) start(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
type NopReadCloser struct{}
|
||||
|
||||
// Read always returns EOF to signal end of input
|
||||
@@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||
@@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol
|
||||
return &httpTransport, nil
|
||||
}
|
||||
|
||||
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||
type MockOriginService struct {
|
||||
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||
type MockOriginHTTPService struct {
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return mos.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (mos MockOriginService) String() string {
|
||||
func (mos MockOriginHTTPService) String() string {
|
||||
return "MockOriginService"
|
||||
}
|
||||
|
||||
func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
@@ -17,7 +17,7 @@ type Rule struct {
|
||||
// A (probably local) address. Requests for a hostname which matches this
|
||||
// rule's hostname pattern will be proxied to the service running on this
|
||||
// address.
|
||||
Service OriginService
|
||||
Service originService
|
||||
|
||||
// Configure the request cloudflared sends to this specific origin.
|
||||
Config OriginRequestConfig
|
||||
|
@@ -14,7 +14,7 @@ func Test_rule_matches(t *testing.T) {
|
||||
type fields struct {
|
||||
Hostname string
|
||||
Path *regexp.Regexp
|
||||
Service OriginService
|
||||
Service originService
|
||||
}
|
||||
type args struct {
|
||||
requestURL *url.URL
|
||||
|
Reference in New Issue
Block a user