mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:59:58 +00:00
TUN-3617: Separate service from client, and implement different client for http vs. tcp origins
- extracted ResponseWriter from proxyConnection - added bastion tests over websocket - removed HTTPResp() - added some docstrings - Renamed some ingress clients as proxies - renamed instances of client to proxy in connection and origin - Stream no longer takes a context and logger.Service
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,10 +20,8 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// OriginService is something a tunnel can proxy traffic to.
|
||||
type OriginService interface {
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
// originService is something a tunnel can proxy traffic to.
|
||||
type originService interface {
|
||||
String() string
|
||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||
@@ -51,10 +48,6 @@ func (o *unixSocketPath) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdown
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
NetDial: o.transport.Dial,
|
||||
@@ -65,130 +58,87 @@ func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn,
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
|
||||
type localService struct {
|
||||
// The URL for the user's origin service
|
||||
RootURL *url.URL
|
||||
// The URL that cloudflared should send requests to.
|
||||
// If this origin requires starting a proxy, this is the proxy's address,
|
||||
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
||||
URL *url.URL
|
||||
type httpService struct {
|
||||
url *url.URL
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
reqURL.Host = o.URL.Host
|
||||
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
func (o *localService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *httpService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
transport, err := newHTTPTransport(o, cfg, log)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
|
||||
// Start a proxy if one is needed
|
||||
if staticHost := o.staticHost(); originRequiresProxy(staticHost, cfg) {
|
||||
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *httpService) String() string {
|
||||
return o.url.String()
|
||||
}
|
||||
|
||||
// Start a listener for the proxy
|
||||
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
||||
listener, err := net.Listen("tcp", proxyAddress)
|
||||
if err != nil {
|
||||
log.Error().Msgf("Cannot start Websocket Proxy Server: %s", err)
|
||||
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
|
||||
// bridgeService is like a jump host, the destination is specified by the client
|
||||
type bridgeService struct {
|
||||
client *tcpClient
|
||||
}
|
||||
|
||||
func newBridgeService() *bridgeService {
|
||||
return &bridgeService{
|
||||
client: &tcpClient{},
|
||||
}
|
||||
}
|
||||
|
||||
// Start the proxy itself
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
streamHandler := websocket.DefaultStreamHandler
|
||||
// This origin's config specifies what type of proxy to start.
|
||||
switch cfg.ProxyType {
|
||||
case socksProxy:
|
||||
log.Info().Msg("SOCKS5 server started")
|
||||
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
|
||||
dialer := socks.NewConnDialer(remoteConn)
|
||||
requestHandler := socks.NewRequestHandler(dialer)
|
||||
socksServer := socks.NewConnectionHandler(requestHandler)
|
||||
func (o *bridgeService) String() string {
|
||||
return "bridge service"
|
||||
}
|
||||
|
||||
_ = socksServer.Serve(wsConn)
|
||||
}
|
||||
case "":
|
||||
log.Debug().Msg("Not starting any websocket proxy")
|
||||
default:
|
||||
log.Error().Msgf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
|
||||
}
|
||||
|
||||
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
|
||||
}()
|
||||
|
||||
// Modify this origin, so that it no longer points at the origin service directly.
|
||||
// Instead, it points at the proxy to the origin service.
|
||||
newURL, err := url.Parse("http://" + listener.Addr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
func (o *bridgeService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||
}
|
||||
o.URL = newURL
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *localService) String() string {
|
||||
if o.isBastion() {
|
||||
return "Bastion"
|
||||
}
|
||||
return o.URL.String()
|
||||
type singleTCPService struct {
|
||||
dest string
|
||||
client *tcpClient
|
||||
}
|
||||
|
||||
func (o *localService) isBastion() bool {
|
||||
return o.URL == nil
|
||||
}
|
||||
|
||||
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.URL.Host
|
||||
req.URL.Scheme = o.URL.Scheme
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *localService) staticHost() string {
|
||||
|
||||
if o.URL == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
addPortIfMissing := func(uri *url.URL, port int) string {
|
||||
if uri.Port() != "" {
|
||||
return uri.Host
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||
}
|
||||
|
||||
switch o.URL.Scheme {
|
||||
func newSingleTCPService(url *url.URL) *singleTCPService {
|
||||
switch url.Scheme {
|
||||
case "ssh":
|
||||
return addPortIfMissing(o.URL, 22)
|
||||
addPortIfMissing(url, 22)
|
||||
case "rdp":
|
||||
return addPortIfMissing(o.URL, 3389)
|
||||
addPortIfMissing(url, 3389)
|
||||
case "smb":
|
||||
return addPortIfMissing(o.URL, 445)
|
||||
addPortIfMissing(url, 445)
|
||||
case "tcp":
|
||||
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
|
||||
addPortIfMissing(url, 7864) // just a random port since there isn't a default in this case
|
||||
}
|
||||
return ""
|
||||
return &singleTCPService{
|
||||
dest: url.Host,
|
||||
client: &tcpClient{},
|
||||
}
|
||||
}
|
||||
|
||||
func addPortIfMissing(uri *url.URL, port int) {
|
||||
if uri.Port() == "" {
|
||||
uri.Host = fmt.Sprintf("%s:%d", uri.Hostname(), port)
|
||||
}
|
||||
}
|
||||
|
||||
func (o *singleTCPService) String() string {
|
||||
return o.dest
|
||||
}
|
||||
|
||||
func (o *singleTCPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
if cfg.ProxyType == socksProxy {
|
||||
o.client.streamHandler = socks.StreamHandler
|
||||
} else {
|
||||
o.client.streamHandler = websocket.DefaultStreamHandler
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||
@@ -228,26 +178,6 @@ func (o *helloWorld) start(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{
|
||||
TLSClientConfig: o.transport.TLSClientConfig,
|
||||
}
|
||||
reqURL.Host = o.server.Addr().String()
|
||||
reqURL.Scheme = "wss"
|
||||
return d.Dial(reqURL.String(), headers)
|
||||
}
|
||||
|
||||
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
||||
return staticHost != "" || cfg.BastionMode
|
||||
}
|
||||
|
||||
// statusCode is an OriginService that just responds with a given HTTP status.
|
||||
// Typical use-case is "user wants the catch-all rule to just respond 404".
|
||||
type statusCode struct {
|
||||
@@ -277,10 +207,6 @@ func (o *statusCode) start(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return o.resp, nil
|
||||
}
|
||||
|
||||
type NopReadCloser struct{}
|
||||
|
||||
// Read always returns EOF to signal end of input
|
||||
@@ -292,7 +218,7 @@ func (nrc *NopReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
func newHTTPTransport(service originService, cfg OriginRequestConfig, log *zerolog.Logger) (*http.Transport, error) {
|
||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, log)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||
@@ -337,19 +263,19 @@ func newHTTPTransport(service OriginService, cfg OriginRequestConfig, log *zerol
|
||||
return &httpTransport, nil
|
||||
}
|
||||
|
||||
// MockOriginService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||
type MockOriginService struct {
|
||||
// MockOriginHTTPService should only be used by other packages to mock OriginService. Set Transport to configure desired RoundTripper behavior.
|
||||
type MockOriginHTTPService struct {
|
||||
Transport http.RoundTripper
|
||||
}
|
||||
|
||||
func (mos MockOriginService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
func (mos MockOriginHTTPService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return mos.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (mos MockOriginService) String() string {
|
||||
func (mos MockOriginHTTPService) String() string {
|
||||
return "MockOriginService"
|
||||
}
|
||||
|
||||
func (mos MockOriginService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (mos MockOriginHTTPService) start(wg *sync.WaitGroup, log *zerolog.Logger, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user