mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:49:58 +00:00
TUN-2640: Users can configure per-origin config. Unify single-rule CLI
flow with multi-rule config file code.
This commit is contained in:
100
origin/tunnel.go
100
origin/tunnel.go
@@ -30,7 +30,6 @@ import (
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
@@ -57,16 +56,13 @@ const (
|
||||
type TunnelConfig struct {
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
ClientID string
|
||||
ClientTlsConfig *tls.Config
|
||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||
CompressionQuality uint64
|
||||
EdgeAddrs []string
|
||||
GracePeriod time.Duration
|
||||
HAConnections int
|
||||
HTTPTransport http.RoundTripper
|
||||
HeartbeatInterval time.Duration
|
||||
Hostname string
|
||||
HTTPHostHeader string
|
||||
IncidentLookup IncidentLookup
|
||||
IsAutoupdated bool
|
||||
IsFreeTunnel bool
|
||||
@@ -76,7 +72,6 @@ type TunnelConfig struct {
|
||||
MaxHeartbeats uint64
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
NoChunkedEncoding bool
|
||||
OriginCert []byte
|
||||
ReportedVersion string
|
||||
Retries uint
|
||||
@@ -84,8 +79,6 @@ type TunnelConfig struct {
|
||||
Tags []tunnelpogs.Tag
|
||||
TlsConfig *tls.Config
|
||||
WSGI bool
|
||||
// OriginUrl may not be used if a user specifies a unix socket.
|
||||
OriginUrl string
|
||||
|
||||
// feature-flag to use new edge reconnect tokens
|
||||
UseReconnectToken bool
|
||||
@@ -618,18 +611,13 @@ func LogServerInfo(
|
||||
}
|
||||
|
||||
type TunnelHandler struct {
|
||||
originUrl string
|
||||
ingressRules ingress.Ingress
|
||||
httpHostHeader string
|
||||
muxer *h2mux.Muxer
|
||||
httpClient http.RoundTripper
|
||||
tlsConfig *tls.Config
|
||||
tags []tunnelpogs.Tag
|
||||
metrics *TunnelMetrics
|
||||
ingressRules ingress.Ingress
|
||||
muxer *h2mux.Muxer
|
||||
tags []tunnelpogs.Tag
|
||||
metrics *TunnelMetrics
|
||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||
connectionID string
|
||||
logger logger.Service
|
||||
noChunkedEncoding bool
|
||||
connectionID string
|
||||
logger logger.Service
|
||||
|
||||
bufferPool *buffer.Pool
|
||||
}
|
||||
@@ -642,31 +630,13 @@ func NewTunnelHandler(ctx context.Context,
|
||||
bufferPool *buffer.Pool,
|
||||
) (*TunnelHandler, string, error) {
|
||||
|
||||
// Check single-origin config
|
||||
var originURL string
|
||||
var err error
|
||||
if config.IngressRules.IsEmpty() {
|
||||
originURL, err = validation.ValidateUrl(config.OriginUrl)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("unable to parse origin URL %#v", originURL)
|
||||
}
|
||||
}
|
||||
|
||||
h := &TunnelHandler{
|
||||
originUrl: originURL,
|
||||
ingressRules: config.IngressRules,
|
||||
httpHostHeader: config.HTTPHostHeader,
|
||||
httpClient: config.HTTPTransport,
|
||||
tlsConfig: config.ClientTlsConfig,
|
||||
tags: config.Tags,
|
||||
metrics: config.Metrics,
|
||||
connectionID: uint8ToString(connectionID),
|
||||
logger: config.Logger,
|
||||
noChunkedEncoding: config.NoChunkedEncoding,
|
||||
bufferPool: bufferPool,
|
||||
}
|
||||
if h.httpClient == nil {
|
||||
h.httpClient = http.DefaultTransport
|
||||
ingressRules: config.IngressRules,
|
||||
tags: config.Tags,
|
||||
metrics: config.Metrics,
|
||||
connectionID: uint8ToString(connectionID),
|
||||
logger: config.Logger,
|
||||
bufferPool: bufferPool,
|
||||
}
|
||||
|
||||
edgeConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
|
||||
@@ -692,7 +662,7 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
h.metrics.incrementRequests(h.connectionID)
|
||||
defer h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||
|
||||
req, reqErr := h.createRequest(stream)
|
||||
req, rule, reqErr := h.createRequest(stream)
|
||||
if reqErr != nil {
|
||||
h.writeErrorResponse(stream, reqErr)
|
||||
return reqErr
|
||||
@@ -705,9 +675,9 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
var resp *http.Response
|
||||
var respErr error
|
||||
if websocket.IsWebSocketUpgrade(req) {
|
||||
resp, respErr = h.serveWebsocket(stream, req)
|
||||
resp, respErr = h.serveWebsocket(stream, req, rule)
|
||||
} else {
|
||||
resp, respErr = h.serveHTTP(stream, req)
|
||||
resp, respErr = h.serveHTTP(stream, req, rule)
|
||||
}
|
||||
if respErr != nil {
|
||||
h.writeErrorResponse(stream, respErr)
|
||||
@@ -717,32 +687,28 @@ func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, error) {
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
func (h *TunnelHandler) createRequest(stream *h2mux.MuxedStream) (*http.Request, *ingress.Rule, error) {
|
||||
req, err := http.NewRequest("GET", "http://localhost:8080", h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||
return nil, nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid request received")
|
||||
return nil, nil, errors.Wrap(err, "invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
if !h.ingressRules.IsEmpty() {
|
||||
ruleNumber := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||
destination := h.ingressRules.Rules[ruleNumber].Service
|
||||
req.URL.Host = destination.Host
|
||||
req.URL.Scheme = destination.Scheme
|
||||
}
|
||||
return req, nil
|
||||
rule, _ := h.ingressRules.FindMatchingRule(req.Host, req.URL.Path)
|
||||
rule.Service.RewriteOriginURL(req.URL)
|
||||
return req, rule, nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
if h.httpHostHeader != "" {
|
||||
req.Header.Set("Host", h.httpHostHeader)
|
||||
req.Host = h.httpHostHeader
|
||||
func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||
req.Header.Set("Host", hostHeader)
|
||||
req.Host = hostHeader
|
||||
}
|
||||
|
||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||
conn, response, err := websocket.ClientConnect(req, rule.ClientTLSConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -758,9 +724,9 @@ func (h *TunnelHandler) serveWebsocket(stream *h2mux.MuxedStream, req *http.Requ
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request) (*http.Response, error) {
|
||||
func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if h.noChunkedEncoding {
|
||||
if rule.Config.DisableChunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
@@ -771,12 +737,12 @@ func (h *TunnelHandler) serveHTTP(stream *h2mux.MuxedStream, req *http.Request)
|
||||
// Request origin to keep connection alive to improve performance
|
||||
req.Header.Set("Connection", "keep-alive")
|
||||
|
||||
if h.httpHostHeader != "" {
|
||||
req.Header.Set("Host", h.httpHostHeader)
|
||||
req.Host = h.httpHostHeader
|
||||
if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
|
||||
req.Header.Set("Host", hostHeader)
|
||||
req.Host = hostHeader
|
||||
}
|
||||
|
||||
response, err := h.httpClient.RoundTrip(req)
|
||||
response, err := rule.HTTPTransport.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error proxying request to origin")
|
||||
}
|
||||
|
Reference in New Issue
Block a user