TUN-3492: Refactor OriginService, shrink its interface

This commit is contained in:
Adam Chalmers
2020-10-30 16:37:40 -05:00
parent 18c359cb86
commit d01770107e
12 changed files with 214 additions and 185 deletions

View File

@@ -1,72 +1,103 @@
package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/websocket"
gws "github.com/gorilla/websocket"
"github.com/pkg/errors"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
Address() string
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
http.RoundTripper
String() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
// starting the origin service.
Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
String() string
// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
// this particular type of origin service's specific routing logic.
RewriteOriginURL(*url.URL)
start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
}
// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
type UnixSocketPath string
func (o UnixSocketPath) Address() string {
return string(o)
// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
type unixSocketPath struct {
path string
transport *http.Transport
}
func (o UnixSocketPath) String() string {
return "unix socket: " + string(o)
func (o *unixSocketPath) String() string {
return "unix socket: " + o.path
}
func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *unixSocketPath) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg)
if err != nil {
return err
}
o.transport = transport
return nil
}
func (o UnixSocketPath) RewriteOriginURL(u *url.URL) {
// No changes necessary because the origin request URL isn't used.
// Instead, HTTPTransport's dial is already configured to address the unix socket.
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
// URL is an OriginService listening on a TCP address
type URL struct {
func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
}
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
type localService struct {
// The URL for the user's origin service
RootURL *url.URL
// The URL that cloudflared should send requests to.
// If this origin requires starting a proxy, this is the proxy's address,
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
URL *url.URL
URL *url.URL
transport *http.Transport
}
func (o *URL) Address() string {
func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
}
func (o *localService) address() string {
return o.URL.String()
}
func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
staticHost := o.staticHost()
if !originRequiresProxy(staticHost, cfg) {
return nil
func (o *localService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg)
if err != nil {
return err
}
o.transport = transport
// Start a proxy if one is needed
staticHost := o.staticHost()
if originRequiresProxy(staticHost, cfg) {
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
return err
}
}
return nil
}
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
// Start a listener for the proxy
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
@@ -111,16 +142,18 @@ func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan str
return nil
}
func (o *URL) String() string {
return o.Address()
func (o *localService) String() string {
return o.address()
}
func (o *URL) RewriteOriginURL(u *url.URL) {
u.Host = o.URL.Host
u.Scheme = o.URL.Scheme
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the origin service.
req.URL.Host = o.URL.Host
req.URL.Scheme = o.URL.Scheme
return o.transport.RoundTrip(req)
}
func (o *URL) staticHost() string {
func (o *localService) staticHost() string {
addPortIfMissing := func(uri *url.URL, port int) string {
if uri.Port() != "" {
@@ -143,21 +176,24 @@ func (o *URL) staticHost() string {
}
// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
type HelloWorld struct {
server net.Listener
// HelloWorld is an OriginService for the built-in Hello World server.
// Users only use this for testing and experimenting with cloudflared.
type helloWorld struct {
server net.Listener
transport *http.Transport
}
func (o *HelloWorld) Address() string {
return o.server.Addr().String()
}
func (o *HelloWorld) String() string {
return "Hello World static HTML service"
func (o *helloWorld) String() string {
return "Hello World test origin"
}
// Start starts a HelloWorld server and stores its address in the Service receiver.
func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
func (o *helloWorld) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
transport, err := newHTTPTransport(o, cfg)
if err != nil {
return err
}
o.transport = transport
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
@@ -171,11 +207,63 @@ func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-c
return nil
}
func (o *HelloWorld) RewriteOriginURL(u *url.URL) {
u.Host = o.Address()
u.Scheme = "https"
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
// Rewrite the request URL so that it goes to the Hello World server.
req.URL.Host = o.server.Addr().String()
req.URL.Scheme = "https"
return o.transport.RoundTrip(req)
}
func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
return staticHost != "" || cfg.BastionMode
}
func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Transport, error) {
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
if err != nil {
return nil, errors.Wrap(err, "Error loading cert pool")
}
httpTransport := http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: cfg.KeepAliveConnections,
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
IdleConnTimeout: cfg.KeepAliveTimeout,
TLSHandshakeTimeout: cfg.TLSTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
}
if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" {
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
}
dialer := &net.Dialer{
Timeout: cfg.ConnectTimeout,
KeepAlive: cfg.TCPKeepAlive,
}
if cfg.NoHappyEyeballs {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
// DialContext depends on which kind of origin is being used.
dialContext := dialer.DialContext
switch service := service.(type) {
// If this origin is a unix socket, enforce network type "unix".
case *unixSocketPath:
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", service.path)
}
// Otherwise, use the regular network config.
default:
httpTransport.DialContext = dialContext
}
return &httpTransport, nil
}