mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:39:58 +00:00
TUN-3492: Refactor OriginService, shrink its interface
This commit is contained in:
@@ -1,72 +1,103 @@
|
||||
package ingress
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/tlsconfig"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// OriginService is something a tunnel can proxy traffic to.
|
||||
type OriginService interface {
|
||||
Address() string
|
||||
// RoundTrip is how cloudflared proxies eyeball requests to the actual origin services
|
||||
http.RoundTripper
|
||||
String() string
|
||||
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
|
||||
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
|
||||
// starting the origin service.
|
||||
Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
|
||||
String() string
|
||||
// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
|
||||
// this particular type of origin service's specific routing logic.
|
||||
RewriteOriginURL(*url.URL)
|
||||
start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
|
||||
}
|
||||
|
||||
// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
|
||||
type UnixSocketPath string
|
||||
|
||||
func (o UnixSocketPath) Address() string {
|
||||
return string(o)
|
||||
// unixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
|
||||
type unixSocketPath struct {
|
||||
path string
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o UnixSocketPath) String() string {
|
||||
return "unix socket: " + string(o)
|
||||
func (o *unixSocketPath) String() string {
|
||||
return "unix socket: " + o.path
|
||||
}
|
||||
|
||||
func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *unixSocketPath) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
transport, err := newHTTPTransport(o, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o UnixSocketPath) RewriteOriginURL(u *url.URL) {
|
||||
// No changes necessary because the origin request URL isn't used.
|
||||
// Instead, HTTPTransport's dial is already configured to address the unix socket.
|
||||
func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// URL is an OriginService listening on a TCP address
|
||||
type URL struct {
|
||||
func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||
return d.Dial(url, headers)
|
||||
}
|
||||
|
||||
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
|
||||
type localService struct {
|
||||
// The URL for the user's origin service
|
||||
RootURL *url.URL
|
||||
// The URL that cloudflared should send requests to.
|
||||
// If this origin requires starting a proxy, this is the proxy's address,
|
||||
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
|
||||
URL *url.URL
|
||||
URL *url.URL
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o *URL) Address() string {
|
||||
func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||
return d.Dial(url, headers)
|
||||
}
|
||||
|
||||
func (o *localService) address() string {
|
||||
return o.URL.String()
|
||||
}
|
||||
|
||||
func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
staticHost := o.staticHost()
|
||||
if !originRequiresProxy(staticHost, cfg) {
|
||||
return nil
|
||||
func (o *localService) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
transport, err := newHTTPTransport(o, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
|
||||
// Start a proxy if one is needed
|
||||
staticHost := o.staticHost()
|
||||
if originRequiresProxy(staticHost, cfg) {
|
||||
if err := o.startProxy(staticHost, wg, log, shutdownC, errC, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *localService) startProxy(staticHost string, wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
|
||||
// Start a listener for the proxy
|
||||
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
|
||||
@@ -111,16 +142,18 @@ func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *URL) String() string {
|
||||
return o.Address()
|
||||
func (o *localService) String() string {
|
||||
return o.address()
|
||||
}
|
||||
|
||||
func (o *URL) RewriteOriginURL(u *url.URL) {
|
||||
u.Host = o.URL.Host
|
||||
u.Scheme = o.URL.Scheme
|
||||
func (o *localService) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the origin service.
|
||||
req.URL.Host = o.URL.Host
|
||||
req.URL.Scheme = o.URL.Scheme
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *URL) staticHost() string {
|
||||
func (o *localService) staticHost() string {
|
||||
|
||||
addPortIfMissing := func(uri *url.URL, port int) string {
|
||||
if uri.Port() != "" {
|
||||
@@ -143,21 +176,24 @@ func (o *URL) staticHost() string {
|
||||
|
||||
}
|
||||
|
||||
// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
|
||||
type HelloWorld struct {
|
||||
server net.Listener
|
||||
// HelloWorld is an OriginService for the built-in Hello World server.
|
||||
// Users only use this for testing and experimenting with cloudflared.
|
||||
type helloWorld struct {
|
||||
server net.Listener
|
||||
transport *http.Transport
|
||||
}
|
||||
|
||||
func (o *HelloWorld) Address() string {
|
||||
return o.server.Addr().String()
|
||||
}
|
||||
|
||||
func (o *HelloWorld) String() string {
|
||||
return "Hello World static HTML service"
|
||||
func (o *helloWorld) String() string {
|
||||
return "Hello World test origin"
|
||||
}
|
||||
|
||||
// Start starts a HelloWorld server and stores its address in the Service receiver.
|
||||
func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
func (o *helloWorld) start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
|
||||
transport, err := newHTTPTransport(o, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.transport = transport
|
||||
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Cannot start Hello World Server")
|
||||
@@ -171,11 +207,63 @@ func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-c
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *HelloWorld) RewriteOriginURL(u *url.URL) {
|
||||
u.Host = o.Address()
|
||||
u.Scheme = "https"
|
||||
func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the request URL so that it goes to the Hello World server.
|
||||
req.URL.Host = o.server.Addr().String()
|
||||
req.URL.Scheme = "https"
|
||||
return o.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
|
||||
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
|
||||
return d.Dial(url, headers)
|
||||
}
|
||||
|
||||
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
|
||||
return staticHost != "" || cfg.BastionMode
|
||||
}
|
||||
|
||||
func newHTTPTransport(service OriginService, cfg OriginRequestConfig) (*http.Transport, error) {
|
||||
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading cert pool")
|
||||
}
|
||||
|
||||
httpTransport := http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
MaxIdleConns: cfg.KeepAliveConnections,
|
||||
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
|
||||
IdleConnTimeout: cfg.KeepAliveTimeout,
|
||||
TLSHandshakeTimeout: cfg.TLSTimeout,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
|
||||
}
|
||||
if _, isHelloWorld := service.(*helloWorld); !isHelloWorld && cfg.OriginServerName != "" {
|
||||
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.ConnectTimeout,
|
||||
KeepAlive: cfg.TCPKeepAlive,
|
||||
}
|
||||
if cfg.NoHappyEyeballs {
|
||||
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
|
||||
}
|
||||
|
||||
// DialContext depends on which kind of origin is being used.
|
||||
dialContext := dialer.DialContext
|
||||
switch service := service.(type) {
|
||||
|
||||
// If this origin is a unix socket, enforce network type "unix".
|
||||
case *unixSocketPath:
|
||||
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
return dialContext(ctx, "unix", service.path)
|
||||
}
|
||||
|
||||
// Otherwise, use the regular network config.
|
||||
default:
|
||||
httpTransport.DialContext = dialContext
|
||||
}
|
||||
|
||||
return &httpTransport, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user