TUN-3617: Separate service from client, and implement different client for http vs. tcp origins

- extracted ResponseWriter from proxyConnection
 - added bastion tests over websocket
 - removed HTTPResp()
 - added some docstrings
 - Renamed some ingress clients as proxies
 - renamed instances of client to proxy in connection and origin
 - Stream no longer takes a context and logger.Service
This commit is contained in:
cthuang
2020-12-09 21:46:53 +00:00
committed by Nuno Diegues
parent 5e2b43adb5
commit e2262085e5
23 changed files with 839 additions and 354 deletions

View File

@@ -8,7 +8,6 @@ import (
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
@@ -21,10 +20,8 @@ import (
"github.com/rs/zerolog"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
// originService is something a tunnel can proxy traffic to.
type originService interface {
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
@@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
return nil
}
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
NetDial: o.transport.Dial,
@@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
return d.Dial(reqURL.String(), headers)
}
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
type localService struct {
// The URL for the user's origin service
RootURL *url.URL
// The URL that cloudflared should send requests to.
// If this origin requires starting a proxy, this is the proxy's address,
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
URL *url.URL
type httpService struct {
url *url.URL
transport *http.Transport
}
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
// Rewrite the request URL so that it goes to the origin service.
reqURL.Host = o.URL.Host
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
return d.Dial(reqURL.String(), headers)
}
func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg, log)
if err != nil {
return err
}
o.transport = transport
// Start a proxy if one is needed
if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) {
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
return err
}
}
return nil
}
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *httpService) String() string {
return o.url.String()
}
// Start a listener for the proxy
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
listener, err := net.Listen("tcp", proxyAddress)
if err != nil {
log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err)
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
// bridgeService is like a jump host, the destination is specified by the client
type bridgeService struct {
client *tcpClient
}
func newBridgeService() *bridgeService {
return &bridgeService{
client: &tcpClient{},
}
}
// Start the proxy itself
wg.Add(1)
go func() {
defer wg.Done()
streamHandler := websocket.DefaultStreamHandler
// This origin's config specifies what type of proxy to start.
switch cfg.ProxyType {
case socksProxy:
log.Info().Msg("SOCKS5 server started")
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
dialer := socks.NewConnDialer(remoteConn)
requestHandler := socks.NewRequestHandler(dialer)
socksServer := socks.NewConnectionHandler(requestHandler)
func (o *bridgeService) String() string {
return "bridge service"
}
_ = socksServer.Serve(wsConn)
}
case "":
log.Debug().Msg("Not starting any websocket proxy")
default:
log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
}
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
}()
// Modify this origin, so that it no longer points at the origin service directly.
// Instead, it points at the proxy to the origin service.
newURL, err := url.Parse("http://" + listener.Addr().String())
if err != nil {
return err
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = websocket.DefaultStreamHandler
}
o.URL = newURL
return nil
}
func (o *localService) String() string {
if o.isBastion() {
return "Bastion"
}
return o.URL.String()
type singleTCPService struct {
dest string
client *tcpClient
}
func (o *localService) isBastion() bool {
return o.URL == nil
}
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.URL.Host
req.URL.Scheme = o.URL.Scheme
return o.transport.RoundTrip(req)
}
func (o *localService) staticHost() string {
if o.URL == nil {
return ""
}
addPortIfMissing := func(uri *url.URL, port int) string {
if uri.Port() != "" {
return uri.Host
}
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
switch o.URL.Scheme {
func newSingleTCPService(url *url.URL) *singleTCPService {
switch url.Scheme {
case "ssh":
return addPortIfMissing(o.URL, 22)
addPortIfMissing(url, 22)
case "rdp":
return addPortIfMissing(o.URL, 3389)
addPortIfMissing(url, 3389)
case "smb":
return addPortIfMissing(o.URL, 445)
addPortIfMissing(url, 445)
case "tcp":
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
}
return ""
return &singleTCPService{
dest: url.Host,
client: &tcpClient{},
}
}
func addPortIfMissing(uri *url.URL, port int) {
if uri.Port() == "" {
uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
}
func (o *singleTCPService) String() string {
return o.dest
}
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
if cfg.ProxyType == socksProxy {
o.client.streamHandler = socks.StreamHandler
} else {
o.client.streamHandler = websocket.DefaultStreamHandler
}
return nil
}
// HelloWorld is an OriginService for the built-in Hello World server.
@@ -228,26 +178,6 @@ func (o *helloWorld) start(
return nil
}
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Host = o.server.Addr().String()
reqURL.Scheme = "wss"
return d.Dial(reqURL.String(), headers)
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
return staticHost != "" || cfg.BastionMode
}
// statusCode is an OriginService that just responds with a given HTTP status.
// Typical use-case is "user wants the catch-all rule to just respond 404".
type statusCode struct {
@@ -277,10 +207,6 @@ func (o *statusCode) start(
return nil
}
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
return o.resp, nil
}
type NopReadCloser struct{}
// Read always returns EOF to signal end of input
@@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error {
return nil
}
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")
@@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol
return &httpTransport, nil
}
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
type MockOriginService struct {
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
type MockOriginHTTPService struct {
Transport http.RoundTripper
}
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
return mos.Transport.RoundTrip(req)
}
func (mos MockOriginService) String() string {
func (mos MockOriginHTTPService) String() string {
return "MockOriginService"
}
func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return nil
}