TUN-2640: Users can configure per-origin config. Unify single-rule CLI

flow with multi-rule config file code.
This commit is contained in:
Adam Chalmers
2020-10-15 16:41:03 -05:00
parent ea71b78e6d
commit e933ef9e1a
13 changed files with 1210 additions and 481 deletions

View File

@@ -1,14 +1,24 @@
package ingress
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/cloudflare/cloudflared/validation"
)
var (
@@ -18,54 +28,93 @@ var (
ErrURLIncompatibleWithIngress = errors.New("You can't set the --url flag (or $TUNNEL_URL) when using multiple-origin ingress rules")
)
// Each rule route traffic from a hostname/path on the public
// internet to the service running on the given URL.
type Rule struct {
// Requests for this hostname will be proxied to this rule's service.
Hostname string
// Finalize the rules by adding missing struct fields and validating each origin.
func (ing *Ingress) setHTTPTransport(logger logger.Service) error {
for ruleNumber, rule := range ing.Rules {
cfg := rule.Config
originCertPool, err := tlsconfig.LoadOriginCA(cfg.CAPool, nil)
if err != nil {
return errors.Wrap(err, "Error loading cert pool")
}
// Path is an optional regex that can specify path-driven ingress rules.
Path *regexp.Regexp
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConns: cfg.KeepAliveConnections,
MaxIdleConnsPerHost: cfg.KeepAliveConnections,
IdleConnTimeout: cfg.KeepAliveTimeout,
TLSHandshakeTimeout: cfg.TLSTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{RootCAs: originCertPool, InsecureSkipVerify: cfg.NoTLSVerify},
}
if _, isHelloWorld := rule.Service.(*HelloWorld); !isHelloWorld && cfg.OriginServerName != "" {
httpTransport.TLSClientConfig.ServerName = cfg.OriginServerName
}
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service *url.URL
}
dialer := &net.Dialer{
Timeout: cfg.ConnectTimeout,
KeepAlive: cfg.TCPKeepAlive,
}
if cfg.NoHappyEyeballs {
dialer.FallbackDelay = -1 // As of Golang 1.12, a negative delay disables "happy eyeballs"
}
func (r Rule) MultiLineString() string {
var out strings.Builder
if r.Hostname != "" {
out.WriteString("\thostname: ")
out.WriteString(r.Hostname)
out.WriteRune('\n')
// DialContext depends on which kind of origin is being used.
dialContext := dialer.DialContext
switch service := rule.Service.(type) {
// If this origin is a unix socket, enforce network type "unix".
case UnixSocketPath:
httpTransport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", service.Address())
}
// Otherwise, use the regular network config.
default:
httpTransport.DialContext = dialContext
}
ing.Rules[ruleNumber].HTTPTransport = httpTransport
ing.Rules[ruleNumber].ClientTLSConfig = httpTransport.TLSClientConfig
}
if r.Path != nil {
out.WriteString("\tpath: ")
out.WriteString(r.Path.String())
out.WriteRune('\n')
}
out.WriteString("\tservice: ")
out.WriteString(r.Service.String())
return out.String()
}
func (r *Rule) Matches(hostname, path string) bool {
hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname)
pathMatch := r.Path == nil || r.Path.MatchString(path)
return hostMatch && pathMatch
// Validate each origin
for _, rule := range ing.Rules {
// If tunnel running in bastion mode, a connection to origin will not exist until initiated by the client.
if rule.Config.BastionMode {
continue
}
// Unix sockets don't have validation
if _, ok := rule.Service.(UnixSocketPath); ok {
continue
}
switch service := rule.Service.(type) {
case UnixSocketPath:
continue
case *HelloWorld:
continue
default:
if err := validation.ValidateHTTPService(service.Address(), rule.Hostname, rule.HTTPTransport); err != nil {
logger.Errorf("unable to connect to the origin: %s", err)
}
}
}
return nil
}
// FindMatchingRule returns the index of the Ingress Rule which matches the given
// hostname and path. This function assumes the last rule matches everything,
// which is the case if the rules were instantiated via the ingress#Validate method
func (ing Ingress) FindMatchingRule(hostname, path string) int {
func (ing Ingress) FindMatchingRule(hostname, path string) (*Rule, int) {
for i, rule := range ing.Rules {
if rule.Matches(hostname, path) {
return i
return &rule, i
}
}
return len(ing.Rules) - 1
i := len(ing.Rules) - 1
return &ing.Rules[i], i
}
func matchHost(ruleHost, reqHost string) bool {
@@ -83,7 +132,56 @@ func matchHost(ruleHost, reqHost string) bool {
// Ingress maps eyeball requests to origins.
type Ingress struct {
Rules []Rule
Rules []Rule
defaults OriginRequestConfig
}
// NewSingleOrigin constructs an Ingress set with only one rule, constructed from
// legacy CLI parameters like --url or --no-chunked-encoding.
func NewSingleOrigin(c *cli.Context, compatibilityMode bool, logger logger.Service) (Ingress, error) {
service, err := parseSingleOriginService(c, compatibilityMode)
if err != nil {
return Ingress{}, err
}
// Construct an Ingress with the single rule.
ing := Ingress{
Rules: []Rule{
{
Service: service,
},
},
defaults: originRequestFromSingeRule(c),
}
err = ing.setHTTPTransport(logger)
return ing, err
}
// Get a single origin service from the CLI/config.
func parseSingleOriginService(c *cli.Context, compatibilityMode bool) (OriginService, error) {
if c.IsSet("hello-world") {
return new(HelloWorld), nil
}
if c.IsSet("url") {
originURLStr, err := config.ValidateUrl(c, compatibilityMode)
if err != nil {
return nil, errors.Wrap(err, "Error validating origin URL")
}
originURL, err := url.Parse(originURLStr)
if err != nil {
return nil, errors.Wrap(err, "couldn't parse origin URL")
}
return &URL{URL: originURL, RootURL: originURL}, nil
}
if c.IsSet("unix-socket") {
unixSocket, err := config.ValidateUnixSocket(c)
if err != nil {
return nil, errors.Wrap(err, "Error validating --unix-socket")
}
return UnixSocketPath(unixSocket), nil
}
return nil, errors.New("You must either set ingress rules in your config file, or use --url or use --unix-socket")
}
// IsEmpty checks if there are any ingress rules.
@@ -91,19 +189,47 @@ func (ing Ingress) IsEmpty() bool {
return len(ing.Rules) == 0
}
func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) {
// StartOrigins will start any origin services managed by cloudflared, e.g. proxy servers or Hello World.
func (ing Ingress) StartOrigins(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error) error {
for _, rule := range ing.Rules {
if err := rule.Service.Start(wg, log, shutdownC, errC, rule.Config); err != nil {
return err
}
}
return nil
}
// CatchAll returns the catch-all rule (i.e. the last rule)
func (ing Ingress) CatchAll() *Rule {
return &ing.Rules[len(ing.Rules)-1]
}
func validate(ingress []config.UnvalidatedIngressRule, defaults OriginRequestConfig) (Ingress, error) {
rules := make([]Rule, len(ingress))
for i, r := range ingress {
service, err := url.Parse(r.Service)
if err != nil {
return Ingress{}, err
}
if service.Scheme == "" || service.Hostname() == "" {
return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service)
}
var service OriginService
if service.Path != "" {
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path.", r.Service)
if strings.HasPrefix(r.Service, "unix:") {
// No validation necessary for unix socket filepath services
service = UnixSocketPath(strings.TrimPrefix(r.Service, "unix:"))
} else if r.Service == "hello_world" || r.Service == "hello-world" || r.Service == "helloworld" {
service = new(HelloWorld)
} else {
// Validate URL services
u, err := url.Parse(r.Service)
if err != nil {
return Ingress{}, err
}
if u.Scheme == "" || u.Hostname() == "" {
return Ingress{}, fmt.Errorf("The service %s must have a scheme and a hostname", r.Service)
}
if u.Path != "" {
return Ingress{}, fmt.Errorf("%s is an invalid address, ingress rules don't support proxying to a different path on the origin service. The path will be the same as the eyeball request's path", r.Service)
}
serviceURL := URL{URL: u}
service = &serviceURL
}
// Ensure that there are no wildcards anywhere except the first character
@@ -125,6 +251,7 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) {
var pathRegex *regexp.Regexp
if r.Path != "" {
var err error
pathRegex, err = regexp.Compile(r.Path)
if err != nil {
return Ingress{}, errors.Wrapf(err, "Rule #%d has an invalid regex", i+1)
@@ -135,9 +262,10 @@ func validate(ingress []config.UnvalidatedIngressRule) (Ingress, error) {
Hostname: r.Hostname,
Service: service,
Path: pathRegex,
Config: SetConfig(defaults, r.OriginRequest),
}
}
return Ingress{Rules: rules}, nil
return Ingress{Rules: rules, defaults: defaults}, nil
}
type errRuleShouldNotBeCatchAll struct {
@@ -151,9 +279,20 @@ func (e errRuleShouldNotBeCatchAll) Error() string {
"will never be triggered.", e.i+1, e.hostname)
}
func ParseIngress(conf *config.Configuration) (Ingress, error) {
// ParseIngress parses, validates and initializes HTTP transports to each origin.
func ParseIngress(conf *config.Configuration, logger logger.Service) (Ingress, error) {
ing, err := ParseIngressDryRun(conf)
if err != nil {
return Ingress{}, err
}
err = ing.setHTTPTransport(logger)
return ing, err
}
// ParseIngressDryRun parses ingress rules, but does not send HTTP requests to the origins.
func ParseIngressDryRun(conf *config.Configuration) (Ingress, error) {
if len(conf.Ingress) == 0 {
return Ingress{}, ErrNoIngressRules
}
return validate(conf.Ingress)
return validate(conf.Ingress, OriginRequestFromYAML(conf.OriginRequest))
}

View File

@@ -2,7 +2,6 @@ package ingress
import (
"net/url"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
@@ -12,16 +11,29 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
)
func TestParseUnixSocket(t *testing.T) {
rawYAML := `
ingress:
- service: unix:/tmp/echo.sock
`
ing, err := ParseIngressDryRun(MustReadIngress(rawYAML))
require.NoError(t, err)
_, ok := ing.Rules[0].Service.(UnixSocketPath)
require.True(t, ok)
}
func Test_parseIngress(t *testing.T) {
localhost8000 := MustParseURL(t, "https://localhost:8000")
localhost8001 := MustParseURL(t, "https://localhost:8001")
defaultConfig := SetConfig(OriginRequestFromYAML(config.OriginRequestConfig{}), config.OriginRequestConfig{})
require.Equal(t, defaultKeepAliveConnections, defaultConfig.KeepAliveConnections)
type args struct {
rawYAML string
}
tests := []struct {
name string
args args
want Ingress
want []Rule
wantErr bool
}{
{
@@ -38,16 +50,18 @@ ingress:
- hostname: "*"
service: https://localhost:8001
`},
want: Ingress{Rules: []Rule{
want: []Rule{
{
Hostname: "tunnel1.example.com",
Service: localhost8000,
Service: &URL{URL: localhost8000},
Config: defaultConfig,
},
{
Hostname: "*",
Service: localhost8001,
Service: &URL{URL: localhost8001},
Config: defaultConfig,
},
}},
},
},
{
name: "Extra keys",
@@ -57,12 +71,13 @@ ingress:
service: https://localhost:8000
extraKey: extraValue
`},
want: Ingress{Rules: []Rule{
want: []Rule{
{
Hostname: "*",
Service: localhost8000,
Service: &URL{URL: localhost8000},
Config: defaultConfig,
},
}},
},
},
{
name: "Hostname can be omitted",
@@ -70,11 +85,12 @@ extraKey: extraValue
ingress:
- service: https://localhost:8000
`},
want: Ingress{Rules: []Rule{
want: []Rule{
{
Service: localhost8000,
Service: &URL{URL: localhost8000},
Config: defaultConfig,
},
}},
},
},
{
name: "Invalid service",
@@ -152,12 +168,12 @@ ingress:
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseIngress(MustReadIngress(tt.args.rawYAML))
got, err := ParseIngressDryRun(MustReadIngress(tt.args.rawYAML))
if (err != nil) != tt.wantErr {
t.Errorf("ParseIngress() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("ParseIngressDryRun() error = %v, wantErr %v", err, tt.wantErr)
return
}
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.want, got.Rules)
})
}
}
@@ -168,118 +184,6 @@ func MustParseURL(t *testing.T, rawURL string) *url.URL {
return u
}
func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service *url.URL
}
type args struct {
requestURL *url.URL
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "Just hostname, pass",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Entire hostname is wildcard, should match everything",
fields: fields{
Hostname: "*",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Just hostname, fail",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://foo.bar"),
},
want: false,
},
{
name: "Just wildcard hostname, pass",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.example.com"),
},
want: true,
},
{
name: "Just wildcard hostname, fail",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://tunnel.com"),
},
want: false,
},
{
name: "Just wildcard outside of subdomain in hostname, fail",
fields: fields{
Hostname: "*example.com",
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com"),
},
want: false,
},
{
name: "Wildcard over multiple subdomains",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.chalmers.example.com"),
},
want: true,
},
{
name: "Hostname and path",
fields: fields{
Hostname: "*.example.com",
Path: regexp.MustCompile("/static/.*\\.html"),
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com/static/index.html"),
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := Rule{
Hostname: tt.fields.Hostname,
Path: tt.fields.Path,
Service: tt.fields.Service,
}
u := tt.args.requestURL
if got := r.Matches(u.Hostname(), u.Path); got != tt.want {
t.Errorf("rule.matches() = %v, want %v", got, tt.want)
}
})
}
}
func BenchmarkFindMatch(b *testing.B) {
rulesYAML := `
ingress:
@@ -291,7 +195,7 @@ ingress:
service: https://localhost:8002
`
ing, err := ParseIngress(MustReadIngress(rulesYAML))
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil {
b.Error(err)
}

View File

@@ -0,0 +1,331 @@
package ingress
import (
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/tlsconfig"
"github.com/urfave/cli/v2"
)
const (
defaultConnectTimeout = 30 * time.Second
defaultTLSTimeout = 10 * time.Second
defaultTCPKeepAlive = 30 * time.Second
defaultKeepAliveConnections = 100
defaultKeepAliveTimeout = 90 * time.Second
defaultProxyAddress = "127.0.0.1"
SSHServerFlag = "ssh-server"
Socks5Flag = "socks5"
ProxyConnectTimeoutFlag = "proxy-connect-timeout"
ProxyTLSTimeoutFlag = "proxy-tls-timeout"
ProxyTCPKeepAlive = "proxy-tcp-keepalive"
ProxyNoHappyEyeballsFlag = "proxy-no-happy-eyeballs"
ProxyKeepAliveConnectionsFlag = "proxy-keepalive-connections"
ProxyKeepAliveTimeoutFlag = "proxy-keepalive-timeout"
HTTPHostHeaderFlag = "http-host-header"
OriginServerNameFlag = "origin-server-name"
NoTLSVerifyFlag = "no-tls-verify"
NoChunkedEncodingFlag = "no-chunked-encoding"
ProxyAddressFlag = "proxy-address"
ProxyPortFlag = "proxy-port"
)
const (
socksProxy = "socks"
)
func originRequestFromSingeRule(c *cli.Context) OriginRequestConfig {
var connectTimeout time.Duration = defaultConnectTimeout
var tlsTimeout time.Duration = defaultTLSTimeout
var tcpKeepAlive time.Duration = defaultTCPKeepAlive
var noHappyEyeballs bool
var keepAliveConnections int = defaultKeepAliveConnections
var keepAliveTimeout time.Duration = defaultKeepAliveTimeout
var httpHostHeader string
var originServerName string
var caPool string
var noTLSVerify bool
var disableChunkedEncoding bool
var bastionMode bool
var proxyAddress string
var proxyPort uint
var proxyType string
if flag := ProxyConnectTimeoutFlag; c.IsSet(flag) {
connectTimeout = c.Duration(flag)
}
if flag := ProxyTLSTimeoutFlag; c.IsSet(flag) {
tlsTimeout = c.Duration(flag)
}
if flag := ProxyTCPKeepAlive; c.IsSet(flag) {
tcpKeepAlive = c.Duration(flag)
}
if flag := ProxyNoHappyEyeballsFlag; c.IsSet(flag) {
noHappyEyeballs = c.Bool(flag)
}
if flag := ProxyKeepAliveConnectionsFlag; c.IsSet(flag) {
keepAliveConnections = c.Int(flag)
}
if flag := ProxyKeepAliveTimeoutFlag; c.IsSet(flag) {
keepAliveTimeout = c.Duration(flag)
}
if flag := HTTPHostHeaderFlag; c.IsSet(flag) {
httpHostHeader = c.String(flag)
}
if flag := OriginServerNameFlag; c.IsSet(flag) {
originServerName = c.String(flag)
}
if flag := tlsconfig.OriginCAPoolFlag; c.IsSet(flag) {
caPool = c.String(flag)
}
if flag := NoTLSVerifyFlag; c.IsSet(flag) {
noTLSVerify = c.Bool(flag)
}
if flag := NoChunkedEncodingFlag; c.IsSet(flag) {
disableChunkedEncoding = c.Bool(flag)
}
if flag := config.BastionFlag; c.IsSet(flag) {
bastionMode = c.Bool(flag)
}
if flag := ProxyAddressFlag; c.IsSet(flag) {
proxyAddress = c.String(flag)
}
if flag := ProxyPortFlag; c.IsSet(flag) {
proxyPort = c.Uint(flag)
}
if c.IsSet(Socks5Flag) {
proxyType = socksProxy
}
return OriginRequestConfig{
ConnectTimeout: connectTimeout,
TLSTimeout: tlsTimeout,
TCPKeepAlive: tcpKeepAlive,
NoHappyEyeballs: noHappyEyeballs,
KeepAliveConnections: keepAliveConnections,
KeepAliveTimeout: keepAliveTimeout,
HTTPHostHeader: httpHostHeader,
OriginServerName: originServerName,
CAPool: caPool,
NoTLSVerify: noTLSVerify,
DisableChunkedEncoding: disableChunkedEncoding,
BastionMode: bastionMode,
ProxyAddress: proxyAddress,
ProxyPort: proxyPort,
ProxyType: proxyType,
}
}
func OriginRequestFromYAML(y config.OriginRequestConfig) OriginRequestConfig {
out := OriginRequestConfig{
ConnectTimeout: defaultConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveConnections: defaultKeepAliveConnections,
KeepAliveTimeout: defaultKeepAliveTimeout,
ProxyAddress: defaultProxyAddress,
}
if y.ConnectTimeout != nil {
out.ConnectTimeout = *y.ConnectTimeout
}
if y.TLSTimeout != nil {
out.TLSTimeout = *y.TLSTimeout
}
if y.TCPKeepAlive != nil {
out.TCPKeepAlive = *y.TCPKeepAlive
}
if y.NoHappyEyeballs != nil {
out.NoHappyEyeballs = *y.NoHappyEyeballs
}
if y.KeepAliveConnections != nil {
out.KeepAliveConnections = *y.KeepAliveConnections
}
if y.KeepAliveTimeout != nil {
out.KeepAliveTimeout = *y.KeepAliveTimeout
}
if y.HTTPHostHeader != nil {
out.HTTPHostHeader = *y.HTTPHostHeader
}
if y.OriginServerName != nil {
out.OriginServerName = *y.OriginServerName
}
if y.CAPool != nil {
out.CAPool = *y.CAPool
}
if y.NoTLSVerify != nil {
out.NoTLSVerify = *y.NoTLSVerify
}
if y.DisableChunkedEncoding != nil {
out.DisableChunkedEncoding = *y.DisableChunkedEncoding
}
if y.BastionMode != nil {
out.BastionMode = *y.BastionMode
}
if y.ProxyAddress != nil {
out.ProxyAddress = *y.ProxyAddress
}
if y.ProxyPort != nil {
out.ProxyPort = *y.ProxyPort
}
if y.ProxyType != nil {
out.ProxyType = *y.ProxyType
}
return out
}
// OriginRequestConfig configures how Cloudflared sends requests to origin
// services.
// Note: To specify a time.Duration in go-yaml, use e.g. "3s" or "24h".
type OriginRequestConfig struct {
// HTTP proxy timeout for establishing a new connection
ConnectTimeout time.Duration `yaml:"connectTimeout"`
// HTTP proxy timeout for completing a TLS handshake
TLSTimeout time.Duration `yaml:"tlsTimeout"`
// HTTP proxy TCP keepalive duration
TCPKeepAlive time.Duration `yaml:"tcpKeepAlive"`
// HTTP proxy should disable "happy eyeballs" for IPv4/v6 fallback
NoHappyEyeballs bool `yaml:"noHappyEyeballs"`
// HTTP proxy maximum keepalive connection pool size
KeepAliveConnections int `yaml:"keepAliveConnections"`
// HTTP proxy timeout for closing an idle connection
KeepAliveTimeout time.Duration `yaml:"keepAliveTimeout"`
// Sets the HTTP Host header for the local webserver.
HTTPHostHeader string `yaml:"httpHostHeader"`
// Hostname on the origin server certificate.
OriginServerName string `yaml:"originServerName"`
// Path to the CA for the certificate of your origin.
// This option should be used only if your certificate is not signed by Cloudflare.
CAPool string `yaml:"caPool"`
// Disables TLS verification of the certificate presented by your origin.
// Will allow any certificate from the origin to be accepted.
// Note: The connection from your machine to Cloudflare's Edge is still encrypted.
NoTLSVerify bool `yaml:"noTLSVerify"`
// Disables chunked transfer encoding.
// Useful if you are running a WSGI server.
DisableChunkedEncoding bool `yaml:"disableChunkedEncoding"`
// Runs as jump host
BastionMode bool `yaml:"bastionMode"`
// Listen address for the proxy.
ProxyAddress string `yaml:"proxyAddress"`
// Listen port for the proxy.
ProxyPort uint `yaml:"proxyPort"`
// What sort of proxy should be started
ProxyType string `yaml:"proxyType"`
}
func (defaults *OriginRequestConfig) setConnectTimeout(overrides config.OriginRequestConfig) {
if val := overrides.ConnectTimeout; val != nil {
defaults.ConnectTimeout = *val
}
}
func (defaults *OriginRequestConfig) setTLSTimeout(overrides config.OriginRequestConfig) {
if val := overrides.TLSTimeout; val != nil {
defaults.TLSTimeout = *val
}
}
func (defaults *OriginRequestConfig) setNoHappyEyeballs(overrides config.OriginRequestConfig) {
if val := overrides.NoHappyEyeballs; val != nil {
defaults.NoHappyEyeballs = *val
}
}
func (defaults *OriginRequestConfig) setKeepAliveConnections(overrides config.OriginRequestConfig) {
if val := overrides.KeepAliveConnections; val != nil {
defaults.KeepAliveConnections = *val
}
}
func (defaults *OriginRequestConfig) setKeepAliveTimeout(overrides config.OriginRequestConfig) {
if val := overrides.KeepAliveTimeout; val != nil {
defaults.KeepAliveTimeout = *val
}
}
func (defaults *OriginRequestConfig) setTCPKeepAlive(overrides config.OriginRequestConfig) {
if val := overrides.TCPKeepAlive; val != nil {
defaults.TCPKeepAlive = *val
}
}
func (defaults *OriginRequestConfig) setHTTPHostHeader(overrides config.OriginRequestConfig) {
if val := overrides.HTTPHostHeader; val != nil {
defaults.HTTPHostHeader = *val
}
}
func (defaults *OriginRequestConfig) setOriginServerName(overrides config.OriginRequestConfig) {
if val := overrides.OriginServerName; val != nil {
defaults.OriginServerName = *val
}
}
func (defaults *OriginRequestConfig) setCAPool(overrides config.OriginRequestConfig) {
if val := overrides.CAPool; val != nil {
defaults.CAPool = *val
}
}
func (defaults *OriginRequestConfig) setNoTLSVerify(overrides config.OriginRequestConfig) {
if val := overrides.NoTLSVerify; val != nil {
defaults.NoTLSVerify = *val
}
}
func (defaults *OriginRequestConfig) setDisableChunkedEncoding(overrides config.OriginRequestConfig) {
if val := overrides.DisableChunkedEncoding; val != nil {
defaults.DisableChunkedEncoding = *val
}
}
func (defaults *OriginRequestConfig) setBastionMode(overrides config.OriginRequestConfig) {
if val := overrides.BastionMode; val != nil {
defaults.BastionMode = *val
}
}
func (defaults *OriginRequestConfig) setProxyPort(overrides config.OriginRequestConfig) {
if val := overrides.ProxyPort; val != nil {
defaults.ProxyPort = *val
}
}
func (defaults *OriginRequestConfig) setProxyAddress(overrides config.OriginRequestConfig) {
if val := overrides.ProxyAddress; val != nil {
defaults.ProxyAddress = *val
}
}
func (defaults *OriginRequestConfig) setProxyType(overrides config.OriginRequestConfig) {
if val := overrides.ProxyType; val != nil {
defaults.ProxyType = *val
}
}
// SetConfig gets config for the requests that cloudflared sends to origins.
// Each field has a setter method which sets a value for the field by trying to find:
// 1. The user config for this rule
// 2. The user config for the overall ingress config
// 3. Defaults chosen by the cloudflared team
// 4. Golang zero values for that type
// If an earlier option isn't set, it will try the next option down.
func SetConfig(defaults OriginRequestConfig, overrides config.OriginRequestConfig) OriginRequestConfig {
cfg := defaults
cfg.setConnectTimeout(overrides)
cfg.setTLSTimeout(overrides)
cfg.setNoHappyEyeballs(overrides)
cfg.setKeepAliveConnections(overrides)
cfg.setKeepAliveTimeout(overrides)
cfg.setTCPKeepAlive(overrides)
cfg.setHTTPHostHeader(overrides)
cfg.setOriginServerName(overrides)
cfg.setCAPool(overrides)
cfg.setNoTLSVerify(overrides)
cfg.setDisableChunkedEncoding(overrides)
cfg.setBastionMode(overrides)
cfg.setProxyPort(overrides)
cfg.setProxyAddress(overrides)
cfg.setProxyType(overrides)
return cfg
}

View File

@@ -0,0 +1,184 @@
package ingress
import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
// Ensure that the nullable config from `config` package and the
// non-nullable config from `ingress` package have the same number of
// fields.
// This test ensures that programmers didn't add a new field to
// one struct and forget to add it to the other ;)
func TestCorrespondingFields(t *testing.T) {
require.Equal(
t,
CountFields(t, config.OriginRequestConfig{}),
CountFields(t, OriginRequestConfig{}),
)
}
func CountFields(t *testing.T, val interface{}) int {
b, err := yaml.Marshal(val)
require.NoError(t, err)
m := make(map[string]interface{}, 0)
err = yaml.Unmarshal(b, &m)
require.NoError(t, err)
return len(m)
}
func TestOriginRequestConfigOverrides(t *testing.T) {
rulesYAML := `
originRequest:
connectTimeout: 1m
tlsTimeout: 1s
noHappyEyeballs: true
tcpKeepAlive: 1s
keepAliveConnections: 1
keepAliveTimeout: 1s
httpHostHeader: abc
originServerName: a1
caPool: /tmp/path0
noTLSVerify: true
disableChunkedEncoding: true
bastionMode: True
proxyAddress: 127.1.2.3
proxyPort: 100
proxyType: socks5
ingress:
- hostname: tun.example.com
service: https://localhost:8000
- hostname: "*"
service: https://localhost:8001
originRequest:
connectTimeout: 2m
tlsTimeout: 2s
noHappyEyeballs: false
tcpKeepAlive: 2s
keepAliveConnections: 2
keepAliveTimeout: 2s
httpHostHeader: def
originServerName: b2
caPool: /tmp/path1
noTLSVerify: false
disableChunkedEncoding: false
bastionMode: false
proxyAddress: interface
proxyPort: 200
proxyType: ""
`
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil {
t.Error(err)
}
// Rule 0 didn't override anything, so it inherits the user-specified
// root-level configuration.
actual0 := ing.Rules[0].Config
expected0 := OriginRequestConfig{
ConnectTimeout: 1 * time.Minute,
TLSTimeout: 1 * time.Second,
NoHappyEyeballs: true,
TCPKeepAlive: 1 * time.Second,
KeepAliveConnections: 1,
KeepAliveTimeout: 1 * time.Second,
HTTPHostHeader: "abc",
OriginServerName: "a1",
CAPool: "/tmp/path0",
NoTLSVerify: true,
DisableChunkedEncoding: true,
BastionMode: true,
ProxyAddress: "127.1.2.3",
ProxyPort: uint(100),
ProxyType: "socks5",
}
require.Equal(t, expected0, actual0)
// Rule 1 overrode all the root-level config.
actual1 := ing.Rules[1].Config
expected1 := OriginRequestConfig{
ConnectTimeout: 2 * time.Minute,
TLSTimeout: 2 * time.Second,
NoHappyEyeballs: false,
TCPKeepAlive: 2 * time.Second,
KeepAliveConnections: 2,
KeepAliveTimeout: 2 * time.Second,
HTTPHostHeader: "def",
OriginServerName: "b2",
CAPool: "/tmp/path1",
NoTLSVerify: false,
DisableChunkedEncoding: false,
BastionMode: false,
ProxyAddress: "interface",
ProxyPort: uint(200),
ProxyType: "",
}
require.Equal(t, expected1, actual1)
}
func TestOriginRequestConfigDefaults(t *testing.T) {
rulesYAML := `
ingress:
- hostname: tun.example.com
service: https://localhost:8000
- hostname: "*"
service: https://localhost:8001
originRequest:
connectTimeout: 2m
tlsTimeout: 2s
noHappyEyeballs: false
tcpKeepAlive: 2s
keepAliveConnections: 2
keepAliveTimeout: 2s
httpHostHeader: def
originServerName: b2
caPool: /tmp/path1
noTLSVerify: false
disableChunkedEncoding: false
bastionMode: false
proxyAddress: interface
proxyPort: 200
proxyType: ""
`
ing, err := ParseIngressDryRun(MustReadIngress(rulesYAML))
if err != nil {
t.Error(err)
}
// Rule 0 didn't override anything, so it inherits the cloudflared defaults
actual0 := ing.Rules[0].Config
expected0 := OriginRequestConfig{
ConnectTimeout: defaultConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveConnections: defaultKeepAliveConnections,
KeepAliveTimeout: defaultKeepAliveTimeout,
ProxyAddress: defaultProxyAddress,
}
require.Equal(t, expected0, actual0)
// Rule 1 overrode all defaults.
actual1 := ing.Rules[1].Config
expected1 := OriginRequestConfig{
ConnectTimeout: 2 * time.Minute,
TLSTimeout: 2 * time.Second,
NoHappyEyeballs: false,
TCPKeepAlive: 2 * time.Second,
KeepAliveConnections: 2,
KeepAliveTimeout: 2 * time.Second,
HTTPHostHeader: "def",
OriginServerName: "b2",
CAPool: "/tmp/path1",
NoTLSVerify: false,
DisableChunkedEncoding: false,
BastionMode: false,
ProxyAddress: "interface",
ProxyPort: uint(200),
ProxyType: "",
}
require.Equal(t, expected1, actual1)
}

181
ingress/origin_service.go Normal file
View File

@@ -0,0 +1,181 @@
package ingress
import (
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"github.com/cloudflare/cloudflared/hello"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/websocket"
"github.com/pkg/errors"
)
// OriginService is something a tunnel can proxy traffic to.
type OriginService interface {
Address() string
// Start the origin service if it's managed by cloudflared, e.g. proxy servers or Hello World.
// If it's not managed by cloudflared, this is a no-op because the user is responsible for
// starting the origin service.
Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error
String() string
// RewriteOriginURL modifies the HTTP request from cloudflared to the origin, so that it apply
// this particular type of origin service's specific routing logic.
RewriteOriginURL(*url.URL)
}
// UnixSocketPath is an OriginService representing a unix socket (which accepts HTTP)
type UnixSocketPath string
func (o UnixSocketPath) Address() string {
return string(o)
}
func (o UnixSocketPath) String() string {
return "unix socket: " + string(o)
}
func (o UnixSocketPath) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
return nil
}
func (o UnixSocketPath) RewriteOriginURL(u *url.URL) {
// No changes necessary because the origin request URL isn't used.
// Instead, HTTPTransport's dial is already configured to address the unix socket.
}
// URL is an OriginService listening on a TCP address
type URL struct {
// The URL for the user's origin service
RootURL *url.URL
// The URL that cloudflared should send requests to.
// If this origin requires starting a proxy, this is the proxy's address,
// and that proxy points to RootURL. Otherwise, this is equal to RootURL.
URL *url.URL
}
func (o *URL) Address() string {
return o.URL.String()
}
func (o *URL) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
staticHost := o.staticHost()
if !originRequiresProxy(staticHost, cfg) {
return nil
}
// Start a listener for the proxy
proxyAddress := net.JoinHostPort(cfg.ProxyAddress, strconv.Itoa(int(cfg.ProxyPort)))
listener, err := net.Listen("tcp", proxyAddress)
if err != nil {
log.Errorf("Cannot start Websocket Proxy Server: %s", err)
return errors.Wrap(err, "Cannot start Websocket Proxy Server")
}
// Start the proxy itself
wg.Add(1)
go func() {
defer wg.Done()
streamHandler := websocket.DefaultStreamHandler
// This origin's config specifies what type of proxy to start.
switch cfg.ProxyType {
case socksProxy:
log.Info("SOCKS5 server started")
streamHandler = func(wsConn *websocket.Conn, remoteConn net.Conn, _ http.Header) {
dialer := socks.NewConnDialer(remoteConn)
requestHandler := socks.NewRequestHandler(dialer)
socksServer := socks.NewConnectionHandler(requestHandler)
socksServer.Serve(wsConn)
}
case "":
log.Debug("Not starting any websocket proxy")
default:
log.Errorf("%s isn't a valid proxy (valid options are {%s})", cfg.ProxyType, socksProxy)
}
errC <- websocket.StartProxyServer(log, listener, staticHost, shutdownC, streamHandler)
}()
// Modify this origin, so that it no longer points at the origin service directly.
// Instead, it points at the proxy to the origin service.
newURL, err := url.Parse("http://" + listener.Addr().String())
if err != nil {
return err
}
o.URL = newURL
return nil
}
func (o *URL) String() string {
return o.Address()
}
func (o *URL) RewriteOriginURL(u *url.URL) {
u.Host = o.URL.Host
u.Scheme = o.URL.Scheme
}
func (o *URL) staticHost() string {
addPortIfMissing := func(uri *url.URL, port int) string {
if uri.Port() != "" {
return uri.Host
}
return fmt.Sprintf("%s:%d", uri.Hostname(), port)
}
switch o.URL.Scheme {
case "ssh":
return addPortIfMissing(o.URL, 22)
case "rdp":
return addPortIfMissing(o.URL, 3389)
case "smb":
return addPortIfMissing(o.URL, 445)
case "tcp":
return addPortIfMissing(o.URL, 7864) // just a random port since there isn't a default in this case
}
return ""
}
// HelloWorld is the built-in Hello World service. Used for testing and experimenting with cloudflared.
type HelloWorld struct {
server net.Listener
}
func (o *HelloWorld) Address() string {
return o.server.Addr().String()
}
func (o *HelloWorld) String() string {
return "Hello World static HTML service"
}
// Start starts a HelloWorld server and stores its address in the Service receiver.
func (o *HelloWorld) Start(wg *sync.WaitGroup, log logger.Service, shutdownC <-chan struct{}, errC chan error, cfg OriginRequestConfig) error {
helloListener, err := hello.CreateTLSListener("127.0.0.1:")
if err != nil {
return errors.Wrap(err, "Cannot start Hello World Server")
}
wg.Add(1)
go func() {
defer wg.Done()
_ = hello.StartHelloWorldServer(log, helloListener, shutdownC)
}()
o.server = helloListener
return nil
}
func (o *HelloWorld) RewriteOriginURL(u *url.URL) {
u.Host = o.Address()
u.Scheme = "https"
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {
return staticHost != "" || cfg.BastionMode
}

57
ingress/rule.go Normal file
View File

@@ -0,0 +1,57 @@
package ingress
import (
"crypto/tls"
"net/http"
"regexp"
"strings"
)
// Rule routes traffic from a hostname/path on the public internet to the
// service running on the given URL.
type Rule struct {
// Requests for this hostname will be proxied to this rule's service.
Hostname string
// Path is an optional regex that can specify path-driven ingress rules.
Path *regexp.Regexp
// A (probably local) address. Requests for a hostname which matches this
// rule's hostname pattern will be proxied to the service running on this
// address.
Service OriginService
// Configure the request cloudflared sends to this specific origin.
Config OriginRequestConfig
// Configures TLS for the cloudflared -> origin request
ClientTLSConfig *tls.Config
// Configures HTTP for the cloudflared -> origin request
HTTPTransport http.RoundTripper
}
// MultiLineString is for outputting rules in a human-friendly way when Cloudflared
// is used as a CLI tool (not as a daemon).
func (r Rule) MultiLineString() string {
var out strings.Builder
if r.Hostname != "" {
out.WriteString("\thostname: ")
out.WriteString(r.Hostname)
out.WriteRune('\n')
}
if r.Path != nil {
out.WriteString("\tpath: ")
out.WriteString(r.Path.String())
out.WriteRune('\n')
}
out.WriteString("\tservice: ")
out.WriteString(r.Service.String())
return out.String()
}
// Matches checks if the rule matches a given hostname/path combination.
func (r *Rule) Matches(hostname, path string) bool {
hostMatch := r.Hostname == "" || r.Hostname == "*" || matchHost(r.Hostname, hostname)
pathMatch := r.Path == nil || r.Path.MatchString(path)
return hostMatch && pathMatch
}

119
ingress/rule_test.go Normal file
View File

@@ -0,0 +1,119 @@
package ingress
import (
"net/url"
"regexp"
"testing"
)
func Test_rule_matches(t *testing.T) {
type fields struct {
Hostname string
Path *regexp.Regexp
Service OriginService
}
type args struct {
requestURL *url.URL
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "Just hostname, pass",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Entire hostname is wildcard, should match everything",
fields: fields{
Hostname: "*",
},
args: args{
requestURL: MustParseURL(t, "https://example.com"),
},
want: true,
},
{
name: "Just hostname, fail",
fields: fields{
Hostname: "example.com",
},
args: args{
requestURL: MustParseURL(t, "https://foo.bar"),
},
want: false,
},
{
name: "Just wildcard hostname, pass",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.example.com"),
},
want: true,
},
{
name: "Just wildcard hostname, fail",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://tunnel.com"),
},
want: false,
},
{
name: "Just wildcard outside of subdomain in hostname, fail",
fields: fields{
Hostname: "*example.com",
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com"),
},
want: false,
},
{
name: "Wildcard over multiple subdomains",
fields: fields{
Hostname: "*.example.com",
},
args: args{
requestURL: MustParseURL(t, "https://adam.chalmers.example.com"),
},
want: true,
},
{
name: "Hostname and path",
fields: fields{
Hostname: "*.example.com",
Path: regexp.MustCompile("/static/.*\\.html"),
},
args: args{
requestURL: MustParseURL(t, "https://www.example.com/static/index.html"),
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := Rule{
Hostname: tt.fields.Hostname,
Path: tt.fields.Path,
Service: tt.fields.Service,
}
u := tt.args.requestURL
if got := r.Matches(u.Hostname(), u.Path); got != tt.want {
t.Errorf("rule.matches() = %v, want %v", got, tt.want)
}
})
}
}