Mirror of https://github.com/cloudflare/cloudflared.git
TUN-3462: Refactor cloudflared to separate origin from connection
@@ -1,14 +0,0 @@
-package origin
-
-import (
-    "net"
-)
-
-// persistentTCPConn is a wrapper around net.Conn that is noop when Close is called
-type persistentConn struct {
-    net.Conn
-}
-
-func (pc *persistentConn) Close() error {
-    return nil
-}
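The deleted persistentConn shows a small standard pattern: embed a net.Conn and override Close as a no-op, so a connection owned elsewhere survives callers that close eagerly. A minimal, self-contained sketch of the same pattern (names here are illustrative, not cloudflared's):

package main

import (
    "fmt"
    "net"
)

// noCloseConn wraps a net.Conn and swallows Close, so the underlying
// connection can be reused after a caller "closes" it.
type noCloseConn struct {
    net.Conn
}

func (c *noCloseConn) Close() error {
    // Deliberately a no-op; only the owner closes the real connection.
    return nil
}

func main() {
    client, server := net.Pipe()
    defer client.Close()
    defer server.Close()

    wrapped := &noCloseConn{Conn: client}
    _ = wrapped.Close() // does not close the pipe
    go server.Write([]byte("hi"))
    buf := make([]byte, 2)
    n, _ := wrapped.Read(buf) // still readable after "Close"
    fmt.Println(string(buf[:n]))
}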
@@ -1,540 +1,63 @@
 package origin
 
 import (
-    "sync"
-    "time"
-
-    "github.com/cloudflare/cloudflared/h2mux"
-
+    "github.com/cloudflare/cloudflared/connection"
     "github.com/prometheus/client_golang/prometheus"
 )
 
-const (
-    metricsNamespace = "cloudflared"
-    tunnelSubsystem  = "tunnel"
-    muxerSubsystem   = "muxer"
-)
+// Metrics uses connection.MetricsNamespace(aka cloudflared) as namespace and connection.TunnelSubsystem
+// (tunnel) as subsystem to keep them consistent with the previous qualifier.
 
-type muxerMetrics struct {
-    rtt              *prometheus.GaugeVec
-    rttMin           *prometheus.GaugeVec
-    rttMax           *prometheus.GaugeVec
-    receiveWindowAve *prometheus.GaugeVec
-    sendWindowAve    *prometheus.GaugeVec
-    receiveWindowMin *prometheus.GaugeVec
-    receiveWindowMax *prometheus.GaugeVec
-    sendWindowMin    *prometheus.GaugeVec
-    sendWindowMax    *prometheus.GaugeVec
-    inBoundRateCurr  *prometheus.GaugeVec
-    inBoundRateMin   *prometheus.GaugeVec
-    inBoundRateMax   *prometheus.GaugeVec
-    outBoundRateCurr *prometheus.GaugeVec
-    outBoundRateMin  *prometheus.GaugeVec
-    outBoundRateMax  *prometheus.GaugeVec
-    compBytesBefore  *prometheus.GaugeVec
-    compBytesAfter   *prometheus.GaugeVec
-    compRateAve      *prometheus.GaugeVec
-}
-
-type TunnelMetrics struct {
-    haConnections     prometheus.Gauge
-    activeStreams     prometheus.Gauge
-    totalRequests     prometheus.Counter
-    requestsPerTunnel *prometheus.CounterVec
-    // concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
-    concurrentRequestsLock      sync.Mutex
-    concurrentRequestsPerTunnel *prometheus.GaugeVec
-    // concurrentRequests records count of concurrent requests for each tunnel
-    concurrentRequests             map[string]uint64
-    maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
-    // concurrentRequests records max count of concurrent requests for each tunnel
-    maxConcurrentRequests map[string]uint64
-    timerRetries          prometheus.Gauge
-    responseByCode        *prometheus.CounterVec
-    responseCodePerTunnel *prometheus.CounterVec
-    serverLocations       *prometheus.GaugeVec
-    // locationLock is a mutex for oldServerLocations
-    locationLock sync.Mutex
-    // oldServerLocations stores the last server the tunnel was connected to
-    oldServerLocations map[string]string
-
-    regSuccess *prometheus.CounterVec
-    regFail    *prometheus.CounterVec
-    rpcFail    *prometheus.CounterVec
-
-    muxerMetrics        *muxerMetrics
-    tunnelsHA           tunnelsForHA
-    userHostnamesCounts *prometheus.CounterVec
-}
-
-func newMuxerMetrics() *muxerMetrics {
-    rtt := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "rtt",
-            Help:      "Round-trip time in millisecond",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(rtt)
-
-    rttMin := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "rtt_min",
-            Help:      "Shortest round-trip time in millisecond",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(rttMin)
-
-    rttMax := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "rtt_max",
-            Help:      "Longest round-trip time in millisecond",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(rttMax)
-
-    receiveWindowAve := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "receive_window_ave",
-            Help:      "Average receive window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(receiveWindowAve)
-
-    sendWindowAve := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "send_window_ave",
-            Help:      "Average send window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(sendWindowAve)
-
-    receiveWindowMin := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "receive_window_min",
-            Help:      "Smallest receive window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(receiveWindowMin)
-
-    receiveWindowMax := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "receive_window_max",
-            Help:      "Largest receive window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(receiveWindowMax)
-
-    sendWindowMin := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "send_window_min",
-            Help:      "Smallest send window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(sendWindowMin)
-
-    sendWindowMax := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "send_window_max",
-            Help:      "Largest send window size in bytes",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(sendWindowMax)
-
-    inBoundRateCurr := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "inbound_bytes_per_sec_curr",
-            Help:      "Current inbounding bytes per second, 0 if there is no incoming connection",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(inBoundRateCurr)
-
-    inBoundRateMin := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "inbound_bytes_per_sec_min",
-            Help:      "Minimum non-zero inbounding bytes per second",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(inBoundRateMin)
-
-    inBoundRateMax := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "inbound_bytes_per_sec_max",
-            Help:      "Maximum inbounding bytes per second",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(inBoundRateMax)
-
-    outBoundRateCurr := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "outbound_bytes_per_sec_curr",
-            Help:      "Current outbounding bytes per second, 0 if there is no outgoing traffic",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(outBoundRateCurr)
-
-    outBoundRateMin := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "outbound_bytes_per_sec_min",
-            Help:      "Minimum non-zero outbounding bytes per second",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(outBoundRateMin)
-
-    outBoundRateMax := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "outbound_bytes_per_sec_max",
-            Help:      "Maximum outbounding bytes per second",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(outBoundRateMax)
-
-    compBytesBefore := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "comp_bytes_before",
-            Help:      "Bytes sent via cross-stream compression, pre compression",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(compBytesBefore)
-
-    compBytesAfter := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "comp_bytes_after",
-            Help:      "Bytes sent via cross-stream compression, post compression",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(compBytesAfter)
-
-    compRateAve := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: muxerSubsystem,
-            Name:      "comp_rate_ave",
-            Help:      "Average outbound cross-stream compression ratio",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(compRateAve)
-
-    return &muxerMetrics{
-        rtt:              rtt,
-        rttMin:           rttMin,
-        rttMax:           rttMax,
-        receiveWindowAve: receiveWindowAve,
-        sendWindowAve:    sendWindowAve,
-        receiveWindowMin: receiveWindowMin,
-        receiveWindowMax: receiveWindowMax,
-        sendWindowMin:    sendWindowMin,
-        sendWindowMax:    sendWindowMax,
-        inBoundRateCurr:  inBoundRateCurr,
-        inBoundRateMin:   inBoundRateMin,
-        inBoundRateMax:   inBoundRateMax,
-        outBoundRateCurr: outBoundRateCurr,
-        outBoundRateMin:  outBoundRateMin,
-        outBoundRateMax:  outBoundRateMax,
-        compBytesBefore:  compBytesBefore,
-        compBytesAfter:   compBytesAfter,
-        compRateAve:      compRateAve,
-    }
-}
-
-func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
-    m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
-    m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
-    m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
-    m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
-    m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
-    m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
-    m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
-    m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
-    m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
-    m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
-    m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
-    m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
-    m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
-    m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
-    m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
-    m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
-    m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
-    m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
-}
-
-func convertRTTMilliSec(t time.Duration) float64 {
-    return float64(t / time.Millisecond)
-}
-
-// Metrics that can be collected without asking the edge
-func NewTunnelMetrics() *TunnelMetrics {
-    haConnections := prometheus.NewGauge(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "ha_connections",
-            Help:      "Number of active ha connections",
-        })
-    prometheus.MustRegister(haConnections)
-
-    activeStreams := h2mux.NewActiveStreamsMetrics(metricsNamespace, tunnelSubsystem)
-
-    totalRequests := prometheus.NewCounter(
+var (
+    totalRequests = prometheus.NewCounter(
         prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
+            Namespace: connection.MetricsNamespace,
+            Subsystem: connection.TunnelSubsystem,
             Name:      "total_requests",
             Help:      "Amount of requests proxied through all the tunnels",
         })
-    prometheus.MustRegister(totalRequests)
-
-    requestsPerTunnel := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "requests_per_tunnel",
-            Help:      "Amount of requests proxied through each tunnel",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(requestsPerTunnel)
-
-    concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
+    concurrentRequests = prometheus.NewGauge(
         prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
+            Namespace: connection.MetricsNamespace,
+            Subsystem: connection.TunnelSubsystem,
             Name:      "concurrent_requests_per_tunnel",
             Help:      "Concurrent requests proxied through each tunnel",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(concurrentRequestsPerTunnel)
-
-    maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "max_concurrent_requests_per_tunnel",
-            Help:      "Largest number of concurrent requests proxied through each tunnel so far",
-        },
-        []string{"connection_id"},
-    )
-    prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
-
-    timerRetries := prometheus.NewGauge(
-        prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "timer_retries",
-            Help:      "Unacknowledged heart beats count",
-        })
-    prometheus.MustRegister(timerRetries)
-
-    responseByCode := prometheus.NewCounterVec(
+        })
+    responseByCode = prometheus.NewCounterVec(
         prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
+            Namespace: connection.MetricsNamespace,
+            Subsystem: connection.TunnelSubsystem,
             Name:      "response_by_code",
             Help:      "Count of responses by HTTP status code",
         },
         []string{"status_code"},
     )
-    prometheus.MustRegister(responseByCode)
-
-    responseCodePerTunnel := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "response_code_per_tunnel",
-            Help:      "Count of responses by HTTP status code fore each tunnel",
-        },
-        []string{"connection_id", "status_code"},
-    )
-    prometheus.MustRegister(responseCodePerTunnel)
-
-    serverLocations := prometheus.NewGaugeVec(
+    haConnections = prometheus.NewGauge(
         prometheus.GaugeOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "server_locations",
-            Help:      "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
-        },
-        []string{"connection_id", "location"},
-    )
-    prometheus.MustRegister(serverLocations)
+            Namespace: connection.MetricsNamespace,
+            Subsystem: connection.TunnelSubsystem,
+            Name:      "ha_connections",
+            Help:      "Number of active ha connections",
+        })
+)
 
-    rpcFail := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "tunnel_rpc_fail",
-            Help:      "Count of RPC connection errors by type",
-        },
-        []string{"error", "rpcName"},
-    )
-    prometheus.MustRegister(rpcFail)
+func init() {
+    prometheus.MustRegister(
+        totalRequests,
+        concurrentRequests,
+        responseByCode,
+        haConnections,
+    )
+}
 
-    registerFail := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "tunnel_register_fail",
-            Help:      "Count of tunnel registration errors by type",
-        },
-        []string{"error", "rpcName"},
-    )
-    prometheus.MustRegister(registerFail)
-
-    userHostnamesCounts := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "user_hostnames_counts",
-            Help:      "Which user hostnames cloudflared is serving",
-        },
-        []string{"userHostname"},
-    )
-    prometheus.MustRegister(userHostnamesCounts)
-
-    registerSuccess := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: metricsNamespace,
-            Subsystem: tunnelSubsystem,
-            Name:      "tunnel_register_success",
-            Help:      "Count of successful tunnel registrations",
-        },
-        []string{"rpcName"},
-    )
-    prometheus.MustRegister(registerSuccess)
-
-    return &TunnelMetrics{
-        haConnections:                  haConnections,
-        activeStreams:                  activeStreams,
-        totalRequests:                  totalRequests,
-        requestsPerTunnel:              requestsPerTunnel,
-        concurrentRequestsPerTunnel:    concurrentRequestsPerTunnel,
-        concurrentRequests:             make(map[string]uint64),
-        maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
-        maxConcurrentRequests:          make(map[string]uint64),
-        timerRetries:                   timerRetries,
-        responseByCode:                 responseByCode,
-        responseCodePerTunnel:          responseCodePerTunnel,
-        serverLocations:                serverLocations,
-        oldServerLocations:             make(map[string]string),
-        muxerMetrics:                   newMuxerMetrics(),
-        tunnelsHA:                      NewTunnelsForHA(),
-        regSuccess:                     registerSuccess,
-        regFail:                        registerFail,
-        rpcFail:                        rpcFail,
-        userHostnamesCounts:            userHostnamesCounts,
-    }
-}
-
-func (t *TunnelMetrics) incrementHaConnections() {
-    t.haConnections.Inc()
+func incrementRequests() {
+    totalRequests.Inc()
+    concurrentRequests.Inc()
 }
 
-func (t *TunnelMetrics) decrementHaConnections() {
-    t.haConnections.Dec()
-}
-
-func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
-    t.muxerMetrics.update(connectionID, metrics)
-}
-
-func (t *TunnelMetrics) incrementRequests(connectionID string) {
-    t.concurrentRequestsLock.Lock()
-    var concurrentRequests uint64
-    var ok bool
-    if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
-        t.concurrentRequests[connectionID]++
-        concurrentRequests++
-    } else {
-        t.concurrentRequests[connectionID] = 1
-        concurrentRequests = 1
-    }
-    if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
-        t.maxConcurrentRequests[connectionID] = concurrentRequests
-        t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
-    }
-    t.concurrentRequestsLock.Unlock()
-
-    t.totalRequests.Inc()
-    t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
-    t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
-}
-
-func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
-    t.concurrentRequestsLock.Lock()
-    if _, ok := t.concurrentRequests[connectionID]; ok {
-        t.concurrentRequests[connectionID]--
-    }
-    t.concurrentRequestsLock.Unlock()
-
-    t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
-}
-
-func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
-    t.responseByCode.WithLabelValues(code).Inc()
-    t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
-}
-
-func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
-    t.locationLock.Lock()
-    defer t.locationLock.Unlock()
-    if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
-        return
-    } else if ok {
-        t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
-    }
-    t.serverLocations.WithLabelValues(connectionID, loc).Inc()
-    t.oldServerLocations[connectionID] = loc
+func decrementConcurrentRequests() {
+    concurrentRequests.Dec()
 }
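The replacement metrics.go drops the TunnelMetrics struct in favor of package-level collectors registered once in init(), which also removes the mutex-guarded per-connection bookkeeping. A minimal sketch of that shape, with literal namespace strings standing in for connection.MetricsNamespace and connection.TunnelSubsystem (an assumption, since those constants live in the connection package):

package origin

import "github.com/prometheus/client_golang/prometheus"

var (
    totalRequestsExample = prometheus.NewCounter(prometheus.CounterOpts{
        Namespace: "cloudflared", // stand-in for connection.MetricsNamespace
        Subsystem: "tunnel",      // stand-in for connection.TunnelSubsystem
        Name:      "total_requests_example",
        Help:      "Amount of requests proxied through all the tunnels",
    })
    concurrentRequestsExample = prometheus.NewGauge(prometheus.GaugeOpts{
        Namespace: "cloudflared",
        Subsystem: "tunnel",
        Name:      "concurrent_requests_example",
        Help:      "Concurrent requests proxied through each tunnel",
    })
)

func init() {
    // Registration happens exactly once, as a package-load side effect.
    prometheus.MustRegister(totalRequestsExample, concurrentRequestsExample)
}

With package-level counters, incrementRequests and decrementConcurrentRequests reduce to lock-free Inc and Dec calls, at the cost of losing the per-connection "connection_id" label.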
@@ -1,121 +0,0 @@
-package origin
-
-import (
-    "strconv"
-    "sync"
-    "testing"
-
-    "github.com/stretchr/testify/assert"
-)
-
-// can only be called once
-var m = NewTunnelMetrics()
-
-func TestConcurrentRequestsSingleTunnel(t *testing.T) {
-    routines := 20
-    var wg sync.WaitGroup
-    wg.Add(routines)
-    for i := 0; i < routines; i++ {
-        go func() {
-            m.incrementRequests("0")
-            wg.Done()
-        }()
-    }
-    wg.Wait()
-    assert.Len(t, m.concurrentRequests, 1)
-    assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
-    assert.Len(t, m.maxConcurrentRequests, 1)
-    assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
-
-    wg.Add(routines / 2)
-    for i := 0; i < routines/2; i++ {
-        go func() {
-            m.decrementConcurrentRequests("0")
-            wg.Done()
-        }()
-    }
-    wg.Wait()
-    assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
-    assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
-}
-
-func TestConcurrentRequestsMultiTunnel(t *testing.T) {
-    m.concurrentRequests = make(map[string]uint64)
-    m.maxConcurrentRequests = make(map[string]uint64)
-    tunnels := 20
-    var wg sync.WaitGroup
-    wg.Add(tunnels)
-    for i := 0; i < tunnels; i++ {
-        go func(i int) {
-            // if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
-            for j := 0; j < i+1; j++ {
-                id := strconv.Itoa(i)
-                m.incrementRequests(id)
-            }
-            wg.Done()
-        }(i)
-    }
-    wg.Wait()
-
-    assert.Len(t, m.concurrentRequests, tunnels)
-    assert.Len(t, m.maxConcurrentRequests, tunnels)
-    for i := 0; i < tunnels; i++ {
-        id := strconv.Itoa(i)
-        assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
-        assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
-    }
-
-    wg.Add(tunnels)
-    for i := 0; i < tunnels; i++ {
-        go func(i int) {
-            for j := 0; j < i+1; j++ {
-                id := strconv.Itoa(i)
-                m.decrementConcurrentRequests(id)
-            }
-            wg.Done()
-        }(i)
-    }
-    wg.Wait()
-
-    assert.Len(t, m.concurrentRequests, tunnels)
-    assert.Len(t, m.maxConcurrentRequests, tunnels)
-    for i := 0; i < tunnels; i++ {
-        id := strconv.Itoa(i)
-        assert.Equal(t, uint64(0), m.concurrentRequests[id])
-        assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
-    }
-}
-
-func TestRegisterServerLocation(t *testing.T) {
-    tunnels := 20
-    var wg sync.WaitGroup
-    wg.Add(tunnels)
-    for i := 0; i < tunnels; i++ {
-        go func(i int) {
-            id := strconv.Itoa(i)
-            m.registerServerLocation(id, "LHR")
-            wg.Done()
-        }(i)
-    }
-    wg.Wait()
-    for i := 0; i < tunnels; i++ {
-        id := strconv.Itoa(i)
-        assert.Equal(t, "LHR", m.oldServerLocations[id])
-    }
-
-    wg.Add(tunnels)
-    for i := 0; i < tunnels; i++ {
-        go func(i int) {
-            id := strconv.Itoa(i)
-            m.registerServerLocation(id, "AUS")
-            wg.Done()
-        }(i)
-    }
-    wg.Wait()
-    for i := 0; i < tunnels; i++ {
-        id := strconv.Itoa(i)
-        assert.Equal(t, "AUS", m.oldServerLocations[id])
-    }
-}
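The deleted tests follow a fan-out/fan-in shape: spawn N goroutines against the shared metrics object, wait on a sync.WaitGroup, then assert on the final counts. That technique in isolation, against a stand-in counter (hypothetical, not the deleted TunnelMetrics):

package main

import (
    "fmt"
    "sync"
)

type counter struct {
    mu sync.Mutex
    n  uint64
}

func (c *counter) inc() {
    c.mu.Lock()
    defer c.mu.Unlock()
    c.n++
}

func main() {
    const routines = 20
    var (
        c  counter
        wg sync.WaitGroup
    )
    wg.Add(routines)
    for i := 0; i < routines; i++ {
        go func() {
            defer wg.Done()
            c.inc()
        }()
    }
    // Wait establishes a happens-before edge, so every increment is
    // visible when the assertions run.
    wg.Wait()
    fmt.Println(c.n == routines) // true
}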
origin/proxy.go (new file, 208 lines)
@@ -0,0 +1,208 @@
+package origin
+
+import (
+    "bufio"
+    "crypto/tls"
+    "io"
+    "net/http"
+    "net/url"
+    "strconv"
+    "strings"
+
+    "github.com/cloudflare/cloudflared/buffer"
+    "github.com/cloudflare/cloudflared/connection"
+    "github.com/cloudflare/cloudflared/logger"
+    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+    "github.com/cloudflare/cloudflared/websocket"
+    "github.com/pkg/errors"
+)
+
+const (
+    TagHeaderNamePrefix = "Cf-Warp-Tag-"
+)
+
+type client struct {
+    config     *ProxyConfig
+    logger     logger.Service
+    bufferPool *buffer.Pool
+}
+
+func NewClient(config *ProxyConfig, logger logger.Service) connection.OriginClient {
+    return &client{
+        config:     config,
+        logger:     logger,
+        bufferPool: buffer.NewPool(512 * 1024),
+    }
+}
+
+type ProxyConfig struct {
+    Client            http.RoundTripper
+    URL               *url.URL
+    TLSConfig         *tls.Config
+    HostHeader        string
+    NoChunkedEncoding bool
+    Tags              []tunnelpogs.Tag
+}
+
+func (c *client) Proxy(w connection.ResponseWriter, req *http.Request, isWebsocket bool) error {
+    incrementRequests()
+    defer decrementConcurrentRequests()
+
+    cfRay := findCfRayHeader(req)
+    lbProbe := isLBProbeRequest(req)
+
+    c.appendTagHeaders(req)
+    c.logRequest(req, cfRay, lbProbe)
+    var (
+        resp *http.Response
+        err  error
+    )
+    if isWebsocket {
+        resp, err = c.proxyWebsocket(w, req)
+    } else {
+        resp, err = c.proxyHTTP(w, req)
+    }
+    if err != nil {
+        c.logger.Errorf("HTTP request error: %s", err)
+        responseByCode.WithLabelValues("502").Inc()
+        w.WriteErrorResponse(err)
+        return err
+    }
+    c.logResponseOk(resp, cfRay, lbProbe)
+    return nil
+}
+
+func (c *client) proxyHTTP(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
+    // Support for WSGI servers by switching transfer encoding from chunked to gzip/deflate
+    if c.config.NoChunkedEncoding {
+        req.TransferEncoding = []string{"gzip", "deflate"}
+        cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
+        if err == nil {
+            req.ContentLength = int64(cLength)
+        }
+    }
+
+    // Ask the origin to keep the connection alive to improve performance
+    req.Header.Set("Connection", "keep-alive")
+
+    c.setHostHeader(req)
+
+    resp, err := c.config.Client.RoundTrip(req)
+    if err != nil {
+        return nil, errors.Wrap(err, "Error proxying request to origin")
+    }
+    defer resp.Body.Close()
+
+    err = w.WriteRespHeaders(resp)
+    if err != nil {
+        return nil, errors.Wrap(err, "Error writing response header")
+    }
+    if isEventStream(resp) {
+        //h.observer.Debug("Detected Server-Side Events from Origin")
+        c.writeEventStream(w, resp.Body)
+    } else {
+        // Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
+        // compression generates its dictionary on first write
+        buf := c.bufferPool.Get()
+        defer c.bufferPool.Put(buf)
+        io.CopyBuffer(w, resp.Body, buf)
+    }
+    return resp, nil
+}
+
+func (c *client) proxyWebsocket(w connection.ResponseWriter, req *http.Request) (*http.Response, error) {
+    c.setHostHeader(req)
+
+    conn, resp, err := websocket.ClientConnect(req, c.config.TLSConfig)
+    if err != nil {
+        return nil, err
+    }
+    defer conn.Close()
+    err = w.WriteRespHeaders(resp)
+    if err != nil {
+        return nil, errors.Wrap(err, "Error writing response header")
+    }
+    // Copy to/from the stream and the underlying connection. Use the underlying
+    // connection because cloudflared doesn't operate on the messages themselves
+    websocket.Stream(conn.UnderlyingConn(), w)
+
+    return resp, nil
+}
+
+func (c *client) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
+    reader := bufio.NewReader(respBody)
+    for {
+        line, err := reader.ReadBytes('\n')
+        if err != nil {
+            break
+        }
+        w.Write(line)
+    }
+}
+
+func (c *client) setHostHeader(req *http.Request) {
+    if c.config.HostHeader != "" {
+        req.Header.Set("Host", c.config.HostHeader)
+        req.Host = c.config.HostHeader
+    }
+}
+
+func (c *client) appendTagHeaders(r *http.Request) {
+    for _, tag := range c.config.Tags {
+        r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
+    }
+}
+
+func (c *client) logRequest(r *http.Request, cfRay string, lbProbe bool) {
+    if cfRay != "" {
+        c.logger.Debugf("CF-RAY: %s %s %s %s", cfRay, r.Method, r.URL, r.Proto)
+    } else if lbProbe {
+        c.logger.Debugf("CF-RAY: %s Load Balancer health check %s %s %s", cfRay, r.Method, r.URL, r.Proto)
+    } else {
+        c.logger.Debugf("CF-RAY: %s All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", cfRay, r.Method, r.URL, r.Proto)
+    }
+    c.logger.Debugf("CF-RAY: %s Request Headers %+v", cfRay, r.Header)
+
+    if contentLen := r.ContentLength; contentLen == -1 {
+        c.logger.Debugf("CF-RAY: %s Request Content length unknown", cfRay)
+    } else {
+        c.logger.Debugf("CF-RAY: %s Request content length %d", cfRay, contentLen)
+    }
+}
+
+func (c *client) logResponseOk(r *http.Response, cfRay string, lbProbe bool) {
+    responseByCode.WithLabelValues("200").Inc()
+    if cfRay != "" {
+        c.logger.Debugf("CF-RAY: %s %s", cfRay, r.Status)
+    } else if lbProbe {
+        c.logger.Debugf("Response to Load Balancer health check %s", r.Status)
+    } else {
+        c.logger.Infof("%s", r.Status)
+    }
+    c.logger.Debugf("CF-RAY: %s Response Headers %+v", cfRay, r.Header)
+
+    if contentLen := r.ContentLength; contentLen == -1 {
+        c.logger.Debugf("CF-RAY: %s Response content length unknown", cfRay)
+    } else {
+        c.logger.Debugf("CF-RAY: %s Response content length %d", cfRay, contentLen)
+    }
+}
+
+func findCfRayHeader(req *http.Request) string {
+    return req.Header.Get("Cf-Ray")
+}
+
+func isLBProbeRequest(req *http.Request) bool {
+    return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
+}
+
+func uint8ToString(input uint8) string {
+    return strconv.FormatUint(uint64(input), 10)
+}
+
+func isEventStream(response *http.Response) bool {
+    if response.Header.Get("content-type") == "text/event-stream" {
+        return true
+    }
+    return false
+}
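Two details of the new proxy.go deserve a note. Event-stream bodies are forwarded line by line so each server-sent event reaches the client as soon as the origin emits it, and ordinary bodies go through io.CopyBuffer with a pooled 512 KiB buffer rather than io.Copy's default 32 KiB allocation (which matters because cross-stream compression builds its dictionary on the first write). A standalone sketch of the line-by-line copy, assuming only the standard library:

package main

import (
    "bufio"
    "io"
    "os"
    "strings"
)

// copyByLine forwards a stream one line at a time; with a flushing
// writer this pushes each SSE event out as it arrives.
func copyByLine(w io.Writer, r io.Reader) {
    reader := bufio.NewReader(r)
    for {
        line, err := reader.ReadBytes('\n')
        if err != nil {
            return // EOF or read error ends the stream, as in writeEventStream
        }
        w.Write(line)
    }
}

func main() {
    body := strings.NewReader("event: ping\ndata: 1\n\n")
    copyByLine(os.Stdout, body)
}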
@@ -7,11 +7,7 @@ import (
     "sync"
     "time"
 
-    "github.com/cloudflare/cloudflared/h2mux"
-    "github.com/cloudflare/cloudflared/logger"
-    "github.com/cloudflare/cloudflared/tunnelrpc"
     tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
-    "github.com/google/uuid"
     "github.com/prometheus/client_golang/prometheus"
 )
 
@@ -138,52 +134,3 @@ func (cm *reconnectCredentialManager) RefreshAuth(
         return nil, err
     }
 }
-
-func ReconnectTunnel(
-    ctx context.Context,
-    muxer *h2mux.Muxer,
-    config *TunnelConfig,
-    logger logger.Service,
-    connectionID uint8,
-    originLocalAddr string,
-    uuid uuid.UUID,
-    credentialManager *reconnectCredentialManager,
-) error {
-    token, err := credentialManager.ReconnectToken()
-    if err != nil {
-        return err
-    }
-    eventDigest, err := credentialManager.EventDigest(connectionID)
-    if err != nil {
-        return err
-    }
-    connDigest, err := credentialManager.ConnDigest(connectionID)
-    if err != nil {
-        return err
-    }
-
-    config.TransportLogger.Debug("initiating RPC stream to reconnect")
-    rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
-    if err != nil {
-        return err
-    }
-    defer rpcClient.Close()
-    // Request server info without blocking tunnel registration; must use capnp library directly.
-    serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
-        return nil
-    })
-    LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
-    registration := rpcClient.ReconnectTunnel(
-        ctx,
-        token,
-        eventDigest,
-        connDigest,
-        config.Hostname,
-        config.RegistrationOptions(connectionID, originLocalAddr, uuid),
-    )
-    if registrationErr := registration.DeserializeError(); registrationErr != nil {
-        // ReconnectTunnel RPC failure
-        return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
-    }
-    return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
-}
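The removed ReconnectTunnel has a fixed lifecycle worth keeping in mind when reading its replacement in the connection package: gather credentials (token plus two digests), open the RPC client, and close it unconditionally before interpreting the registration result. A sketch of that shape with hypothetical stand-in types (the real client is a capnp-generated TunnelServer client):

package main

import "fmt"

type credentials struct{ token, eventDigest, connDigest []byte }

type rpcClient struct{}

func (c *rpcClient) Close()                           {}
func (c *rpcClient) Reconnect(cred credentials) error { return nil }

func dialRPC() (*rpcClient, error) { return &rpcClient{}, nil }

func reconnect(cred credentials) error {
    client, err := dialRPC()
    if err != nil {
        return err
    }
    // Close runs whether or not the RPC succeeds, mirroring the
    // defer rpcClient.Close() in the removed function.
    defer client.Close()
    if err := client.Reconnect(cred); err != nil {
        return fmt.Errorf("reconnect RPC failed: %w", err)
    }
    return nil
}

func main() {
    fmt.Println(reconnect(credentials{}))
}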
@@ -8,7 +8,6 @@ import (
 
     "github.com/google/uuid"
 
-    "github.com/cloudflare/cloudflared/buffer"
     "github.com/cloudflare/cloudflared/connection"
     "github.com/cloudflare/cloudflared/edgediscovery"
     "github.com/cloudflare/cloudflared/h2mux"
@@ -56,8 +55,7 @@ type Supervisor struct {
     logger logger.Service
 
     reconnectCredentialManager *reconnectCredentialManager
-
-    bufferPool *buffer.Pool
+    useReconnectToken          bool
 }
 
 type resolveResult struct {
@@ -76,28 +74,33 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
         err error
     )
     if len(config.EdgeAddrs) > 0 {
-        edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
+        edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
     } else {
-        edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
+        edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
     }
     if err != nil {
         return nil, err
     }
 
+    useReconnectToken := false
+    if config.ClassicTunnel != nil {
+        useReconnectToken = config.ClassicTunnel.UseReconnectToken
+    }
+
     return &Supervisor{
         cloudflaredUUID:   cloudflaredUUID,
         config:            config,
         edgeIPs:           edgeIPs,
         tunnelErrors:      make(chan tunnelError),
         tunnelsConnecting: map[int]chan struct{}{},
-        logger:                     config.Logger,
-        reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
-        bufferPool:                 buffer.NewPool(512 * 1024),
+        logger:                     config.Observer,
+        reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
+        useReconnectToken:          useReconnectToken,
     }, nil
 }
 
 func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
-    logger := s.config.Logger
+    logger := s.config.Observer
     if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
         return err
     }
@@ -110,7 +113,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
     refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
     var refreshAuthBackoffTimer <-chan time.Time
 
-    if s.config.UseReconnectToken {
+    if s.useReconnectToken {
         if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
             refreshAuthBackoffTimer = timer
         } else {
@@ -227,7 +230,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
         return
     }
 
-    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
     // If the first tunnel disconnects, keep restarting it.
     edgeErrors := 0
     for s.unusedIPs() {
@@ -239,7 +242,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
         return
     // try the next address if it was a dialError(network problem) or
     // dupConnRegisterTunnelError
-    case connection.DialError, dupConnRegisterTunnelError:
+    case edgediscovery.DialError, connection.DupConnRegisterTunnelError:
         edgeErrors++
     default:
         return
@@ -250,7 +253,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
             return
         }
     }
-    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, reconnectCh)
 }
}
@@ -269,7 +272,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
     if err != nil {
         return
     }
-    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, reconnectCh)
 }
 
 func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@@ -301,7 +304,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
         return nil, err
     }
 
-    edgeConn, err := connection.DialEdge(ctx, dialTimeout, s.config.TlsConfig, arbitraryEdgeIP)
+    edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
     if err != nil {
         return nil, err
     }
@@ -311,8 +314,8 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
         // This callback is invoked by h2mux when the edge initiates a stream.
         return nil // noop
     })
-    muxerConfig := s.config.muxerConfig(handler)
-    muxer, err := h2mux.Handshake(edgeConn, edgeConn, muxerConfig, s.config.Metrics.activeStreams)
+    muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logger)
+    muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams)
     if err != nil {
         return nil, err
     }
@@ -323,23 +326,15 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
         <-muxer.Shutdown()
     }()
 
-    rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate)
+    stream, err := muxer.OpenRPCStream(ctx)
     if err != nil {
         return nil, err
     }
+    rpcClient := connection.NewTunnelServerClient(ctx, stream, s.logger)
     defer rpcClient.Close()
 
     const arbitraryConnectionID = uint8(0)
     registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
     registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
-    authResponse, err := rpcClient.Authenticate(
-        ctx,
-        s.config.OriginCert,
-        s.config.Hostname,
-        registrationOptions,
-    )
-    if err != nil {
-        return nil, err
-    }
-    return authResponse.Outcome(), nil
+    return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
 }
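Note how the supervisor decides whether to rotate to another edge address: it switches on the concrete error type, retrying only for dial failures and duplicate-connection registrations. The same decision in isolation, with stand-in error types (the real ones are edgediscovery.DialError and connection.DupConnRegisterTunnelError):

package main

import "fmt"

type dialError struct{ cause error }

func (e dialError) Error() string { return "dial error: " + e.cause.Error() }

type dupConnError struct{}

func (e dupConnError) Error() string { return "already connected to this server" }

// shouldTryNextAddress reports whether the serve loop should rotate
// to another edge address instead of giving up.
func shouldTryNextAddress(err error) bool {
    switch err.(type) {
    case dialError, dupConnError:
        return true // network problem or duplicate registration
    default:
        return false // unrecoverable for this loop
    }
}

func main() {
    fmt.Println(shouldTryNextAddress(dupConnError{}))          // true
    fmt.Println(shouldTryNextAddress(fmt.Errorf("rpc error"))) // false
}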
origin/tunnel.go (579 changed lines)
@@ -5,9 +5,7 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -17,26 +15,22 @@ import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/buffer"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/ui"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/edgediscovery"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/ingress"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/signal"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 15 * time.Second
|
||||
openStreamTimeout = 30 * time.Second
|
||||
muxerTimeout = 5 * time.Second
|
||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||
DuplicateConnectionError = "EDUPCONN"
|
||||
FeatureSerializedHeaders = "serialized_headers"
|
||||
FeatureQuickReconnects = "quick_reconnects"
|
||||
@@ -52,49 +46,31 @@ const (
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
ClientID string
|
||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||
CompressionQuality uint64
|
||||
EdgeAddrs []string
|
||||
GracePeriod time.Duration
|
||||
HAConnections int
|
||||
HeartbeatInterval time.Duration
|
||||
Hostname string
|
||||
IncidentLookup IncidentLookup
|
||||
IsAutoupdated bool
|
||||
IsFreeTunnel bool
|
||||
LBPool string
|
||||
Logger logger.Service
|
||||
TransportLogger logger.Service
|
||||
MaxHeartbeats uint64
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
OriginCert []byte
|
||||
ReportedVersion string
|
||||
Retries uint
|
||||
RunFromTerminal bool
|
||||
Tags []tunnelpogs.Tag
|
||||
TlsConfig *tls.Config
|
||||
WSGI bool
|
||||
ConnectionConfig *connection.Config
|
||||
ProxyConfig *ProxyConfig
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
ClientID string
|
||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||
EdgeAddrs []string
|
||||
HAConnections int
|
||||
IncidentLookup IncidentLookup
|
||||
IsAutoupdated bool
|
||||
IsFreeTunnel bool
|
||||
LBPool string
|
||||
Logger logger.Service
|
||||
Observer *connection.Observer
|
||||
ReportedVersion string
|
||||
Retries uint
|
||||
RunFromTerminal bool
|
||||
TLSConfig *tls.Config
|
||||
|
||||
// feature-flag to use new edge reconnect tokens
|
||||
UseReconnectToken bool
|
||||
|
||||
NamedTunnel *NamedTunnelConfig
|
||||
ReplaceExisting bool
|
||||
TunnelEventChan chan<- ui.TunnelEvent
|
||||
NamedTunnel *connection.NamedTunnelConfig
|
||||
ClassicTunnel *connection.ClassicTunnelConfig
|
||||
MuxerConfig *connection.MuxerConfig
|
||||
TunnelEventChan chan ui.TunnelEvent
|
||||
IngressRules ingress.Ingress
|
||||
}
|
||||
|
||||
type dupConnRegisterTunnelError struct{}
|
||||
|
||||
var errDuplicationConnection = &dupConnRegisterTunnelError{}
|
||||
|
||||
func (e dupConnRegisterTunnelError) Error() string {
|
||||
return "already connected to this server, trying another address"
|
||||
}
|
||||
|
||||
type muxerShutdownError struct{}
|
||||
|
||||
func (e muxerShutdownError) Error() string {
|
||||
@@ -125,18 +101,6 @@ func (e clientRegisterTunnelError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) muxerConfig(handler h2mux.MuxedStreamHandler) h2mux.MuxerConfig {
|
||||
return h2mux.MuxerConfig{
|
||||
Timeout: muxerTimeout,
|
||||
Handler: handler,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: c.HeartbeatInterval,
|
||||
MaxHeartbeats: c.MaxHeartbeats,
|
||||
Logger: c.TransportLogger,
|
||||
CompressionQuality: h2mux.CompressionSetting(c.CompressionQuality),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string, uuid uuid.UUID) *tunnelpogs.RegistrationOptions {
|
||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||
@@ -148,12 +112,12 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
|
||||
ExistingTunnelPolicy: policy,
|
||||
PoolName: c.LBPool,
|
||||
Tags: c.Tags,
|
||||
Tags: c.ProxyConfig.Tags,
|
||||
ConnectionID: connectionID,
|
||||
OriginLocalIP: OriginLocalIP,
|
||||
IsAutoupdated: c.IsAutoupdated,
|
||||
RunFromTerminal: c.RunFromTerminal,
|
||||
CompressionQuality: c.CompressionQuality,
|
||||
CompressionQuality: uint64(c.MuxerConfig.CompressionSetting),
|
||||
UUID: uuid.String(),
|
||||
Features: c.SupportedFeatures(),
|
||||
}
|
||||
@@ -167,8 +131,8 @@ func (c *TunnelConfig) ConnectionOptions(originLocalAddr string, numPreviousAtte
|
||||
return &tunnelpogs.ConnectionOptions{
|
||||
Client: c.NamedTunnel.Client,
|
||||
OriginLocalIP: originIP,
|
||||
ReplaceExisting: c.ReplaceExisting,
|
||||
CompressionQuality: uint8(c.CompressionQuality),
|
||||
ReplaceExisting: c.ConnectionConfig.ReplaceExisting,
|
||||
CompressionQuality: uint8(c.MuxerConfig.CompressionSetting),
|
||||
NumPreviousAttempts: numPreviousAttempts,
|
||||
}
|
||||
}
|
||||
@@ -181,35 +145,6 @@ func (c *TunnelConfig) SupportedFeatures() []string {
|
||||
return features
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) IsTrialTunnel() bool {
|
||||
return c.Hostname == ""
|
||||
}
|
||||
|
||||
type NamedTunnelConfig struct {
|
||||
Auth pogs.TunnelAuth
|
||||
ID uuid.UUID
|
||||
Client pogs.ClientInfo
|
||||
Protocol Protocol
|
||||
}
|
||||
|
||||
type Protocol int64
|
||||
|
||||
const (
|
||||
h2muxProtocol Protocol = iota
|
||||
http2Protocol
|
||||
)
|
||||
|
||||
func ParseProtocol(s string) (Protocol, bool) {
|
||||
switch s {
|
||||
case "h2mux":
|
||||
return h2muxProtocol, true
|
||||
case "http2":
|
||||
return http2Protocol, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
||||
s, err := NewSupervisor(config, cloudflaredID)
|
||||
if err != nil {
|
||||
@@ -225,11 +160,11 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
connectionIndex uint8,
|
||||
connectedSignal *signal.Signal,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
bufferPool *buffer.Pool,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
) error {
|
||||
config.Metrics.incrementHaConnections()
|
||||
defer config.Metrics.decrementHaConnections()
|
||||
haConnections.Inc()
|
||||
defer haConnections.Dec()
|
||||
|
||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||
connectedFuse := h2mux.NewBooleanFuse()
|
||||
go func() {
|
||||
@@ -244,12 +179,10 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
ctx,
|
||||
credentialManager,
|
||||
config,
|
||||
config.Logger,
|
||||
addr, connectionIndex,
|
||||
connectedFuse,
|
||||
&backoff,
|
||||
cloudflaredUUID,
|
||||
bufferPool,
|
||||
reconnectCh,
|
||||
)
|
||||
if recoverable {
|
||||
@@ -257,7 +190,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
if config.TunnelEventChan != nil {
|
||||
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
|
||||
}
|
||||
config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
|
||||
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
|
||||
backoff.Backoff(ctx)
|
||||
continue
|
||||
}
|
||||
@@ -270,13 +203,11 @@ func ServeTunnel(
|
||||
ctx context.Context,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
config *TunnelConfig,
|
||||
logger logger.Service,
|
||||
addr *net.TCPAddr,
|
||||
connectionIndex uint8,
|
||||
connectedFuse *h2mux.BooleanFuse,
|
||||
fuse *h2mux.BooleanFuse,
|
||||
backoff *BackoffHandler,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
bufferPool *buffer.Pool,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
) (err error, recoverable bool) {
|
||||
// Treat panics as recoverable errors
|
||||
@@ -287,6 +218,7 @@ func ServeTunnel(
|
||||
if !ok {
|
||||
err = fmt.Errorf("ServeTunnel: %v", r)
|
||||
}
|
||||
err = errors.Wrapf(err, "stack trace: %s", string(debug.Stack()))
|
||||
recoverable = true
|
||||
}
|
||||
}()
|
||||
@@ -298,203 +230,107 @@ func ServeTunnel(
|
||||
}()
|
||||
}
|
||||
|
||||
connectionTag := uint8ToString(connectionIndex)
|
||||
|
||||
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == http2Protocol {
|
||||
return ServeNamedTunnel(ctx, config, connectionIndex, addr, connectedFuse, reconnectCh)
|
||||
}
|
||||
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
|
||||
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case connection.DialError:
|
||||
logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
|
||||
case h2mux.MuxerHandshakeError:
|
||||
logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
|
||||
default:
|
||||
logger.Errorf("Connection %d failed: %s", connectionIndex, err)
|
||||
return err, false
|
||||
}
|
||||
return err, true
|
||||
}
|
||||
connectedFuse := &connectedFuse{
|
||||
fuse: fuse,
|
||||
backoff: backoff,
|
||||
}
|
||||
if config.NamedTunnel != nil && config.NamedTunnel.Protocol == connection.HTTP2 {
|
||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
||||
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
|
||||
}
|
||||
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
|
||||
}
|
||||
|
||||
func ServeH2mux(
|
||||
ctx context.Context,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
config *TunnelConfig,
|
||||
edgeConn net.Conn,
|
||||
connectionIndex uint8,
|
||||
connectedFuse *connectedFuse,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
) (err error, recoverable bool) {
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
|
||||
if err != nil {
|
||||
return err, recoverable
|
||||
}
|
||||
|
||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
|
||||
errGroup.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
connectedFuse.Fuse(true)
|
||||
backoff.SetGracePeriod()
|
||||
}
|
||||
}()
|
||||
|
||||
if config.UseReconnectToken && connectedFuse.Value() {
|
||||
err := ReconnectTunnel(serveCtx, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
// log errors and proceed to RegisterTunnel
|
||||
logger.Errorf("Couldn't reconnect connection %d. Reregistering it instead. Error was: %v", connectionIndex, err)
|
||||
if config.NamedTunnel != nil {
|
||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
||||
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
||||
}
|
||||
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
|
||||
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
||||
for {
|
||||
select {
|
||||
case <-serveCtx.Done():
|
||||
// UnregisterTunnel blocks until the RPC call returns
|
||||
if connectedFuse.Value() {
|
||||
if config.NamedTunnel != nil {
|
||||
_ = UnregisterConnection(ctx, handler.muxer, config)
|
||||
} else {
|
||||
_ = UnregisterTunnel(handler.muxer, config)
|
||||
}
|
||||
}
|
||||
handler.muxer.Shutdown()
|
||||
return nil
|
||||
case <-updateMetricsTickC:
|
||||
handler.UpdateMetrics(connectionTag)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case reconnect := <-reconnectCh:
|
||||
return &reconnect
|
||||
case <-serveCtx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
// All routines should stop when muxer finish serving. When muxer is shutdown
|
||||
// gracefully, it doesn't return an error, so we need to return errMuxerShutdown
|
||||
// here to notify other routines to stop
|
||||
err := handler.muxer.Serve(serveCtx)
|
||||
if err == nil {
|
||||
return muxerShutdownError{}
|
||||
}
|
||||
return err
|
||||
})
|
||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
switch err := err.(type) {
|
||||
case *dupConnRegisterTunnelError:
|
||||
case *connection.DupConnRegisterTunnelError:
|
||||
// don't retry this connection anymore, let supervisor pick new a address
|
||||
return err, false
|
||||
case *serverRegisterTunnelError:
|
||||
logger.Errorf("Register tunnel error from server side: %s", err.cause)
|
||||
config.Logger.Errorf("Register tunnel error from server side: %s", err.cause)
|
||||
// Don't send registration error return from server to Sentry. They are
|
||||
// logged on server side
|
||||
if incidents := config.IncidentLookup.ActiveIncidents(); len(incidents) > 0 {
|
||||
logger.Error(activeIncidentsMsg(incidents))
|
||||
config.Logger.Error(activeIncidentsMsg(incidents))
|
||||
}
|
||||
return err.cause, !err.permanent
|
||||
case *clientRegisterTunnelError:
|
||||
logger.Errorf("Register tunnel error on client side: %s", err.cause)
|
||||
config.Logger.Errorf("Register tunnel error on client side: %s", err.cause)
|
||||
return err, true
|
||||
case *muxerShutdownError:
|
||||
logger.Info("Muxer shutdown")
|
||||
config.Logger.Info("Muxer shutdown")
|
||||
return err, true
|
||||
case *ReconnectSignal:
|
||||
logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
|
||||
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
|
||||
err.DelayBeforeReconnect()
|
||||
return err, true
|
||||
default:
|
||||
if err == context.Canceled {
|
||||
logger.Debugf("Serve tunnel error: %s", err)
|
||||
config.Logger.Debugf("Serve tunnel error: %s", err)
|
||||
return err, false
|
||||
}
|
||||
logger.Errorf("Serve tunnel error: %s", err)
|
||||
config.Logger.Errorf("Serve tunnel error: %s", err)
|
||||
return err, true
|
||||
}
|
||||
}
|
||||
return nil, true
|
||||
}
func RegisterConnectionWithH2Mux(
	ctx context.Context,
	muxer *h2mux.Muxer,
	config *TunnelConfig,
	connectionIndex uint8,
	originLocalAddr string,
	numPreviousAttempts uint8,
) error {
	const registerConnection = "registerConnection"

	config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
	rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
	if err != nil {
		return err
	}
	defer rpcClient.Close()

	conn, err := rpcClient.RegisterConnection(
		ctx,
		config.NamedTunnel.Auth,
		config.NamedTunnel.ID,
		connectionIndex,
		config.ConnectionOptions(originLocalAddr, numPreviousAttempts),
	)
	if err != nil {
		if err.Error() == DuplicateConnectionError {
			config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
			return errDuplicationConnection
		}
		config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
		return serverRegistrationErrorFromRPC(err)
	}

	config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
	config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)

	// If launch-ui flag is set, send connect msg
	if config.TunnelEventChan != nil {
		config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Connected, Location: conn.Location}
	}

	return nil
}
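RegisterConnectionWithH2Mux counts every outcome through labeled regSuccess/regFail counters. A hedged sketch of that counter shape using the Prometheus client directly (the metric name and label values below are illustrative placeholders, not necessarily the ones cloudflared registers):

package main

import (
	"fmt"

	"github.com/prometheus/client_golang/prometheus"
)

func main() {
	// CounterVec keyed by failure reason and RPC name, mirroring the
	// regFail.WithLabelValues("dup_edge_conn", "registerConnection") calls above.
	regFail := prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Namespace: "cloudflared", // illustrative namespace
			Subsystem: "tunnel",      // illustrative subsystem
			Name:      "register_fail",
			Help:      "Count of tunnel registration errors by type",
		},
		[]string{"error", "rpcName"},
	)
	prometheus.MustRegister(regFail)

	regFail.WithLabelValues("dup_edge_conn", "registerConnection").Inc()
	fmt.Println("failure counted")
}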
func ServeNamedTunnel(
func ServeHTTP2(
	ctx context.Context,
	config *TunnelConfig,
	tlsServerConn net.Conn,
	connOptions *tunnelpogs.ConnectionOptions,
	connIndex uint8,
	addr *net.TCPAddr,
	connectedFuse *h2mux.BooleanFuse,
	connectedFuse connection.ConnectedFuse,
	reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
	tlsServerConn, err := connection.DialEdge(ctx, dialTimeout, config.TlsConfig, addr)
	if err != nil {
		return err, true
	}

	cfdServer, err := newHTTP2Server(config, connIndex, tlsServerConn.LocalAddr(), connectedFuse)
	server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
	if err != nil {
		return err, false
	}

	errGroup, serveCtx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		cfdServer.serve(serveCtx, tlsServerConn)
		server.Serve(serveCtx)
		return fmt.Errorf("Connection with edge closed")
	})
	errGroup.Go(func() error {
		select {
		case reconnect := <-reconnectCh:
			return &reconnect
		case <-serveCtx.Done():
			return nil
		}
	})
	errGroup.Go(listenReconnect(serveCtx, reconnectCh))

	err = errGroup.Wait()
	if err != nil {
@@ -503,229 +339,29 @@ func ServeNamedTunnel(
	return nil, false
}
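ServeHTTP2 now receives an already-established tlsServerConn instead of dialing inside the function; the connection.DialEdge call shown above is the removed path. A rough standard-library sketch of dialing a TLS edge address with a timeout (the address, SNI, and timeout below are placeholders, and this is not the project's DialEdge helper):

package main

import (
	"crypto/tls"
	"log"
	"net"
	"time"
)

func main() {
	dialer := &net.Dialer{Timeout: 15 * time.Second} // stands in for dialTimeout
	conn, err := tls.DialWithDialer(dialer, "tcp", "203.0.113.1:7844", &tls.Config{
		ServerName: "example.invalid", // placeholder SNI
	})
	if err != nil {
		// A dial failure is treated as retryable by the caller.
		log.Fatalf("edge dial failed: %v", err)
	}
	defer conn.Close()
	log.Printf("connected to edge at %s", conn.RemoteAddr())
}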
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
	if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
		return &serverRegisterTunnelError{
			cause:     retryable.Unwrap(),
			permanent: false,
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
	return func() error {
		select {
		case reconnect := <-reconnectCh:
			return &reconnect
		case <-ctx.Done():
			return nil
		}
	}
	return &serverRegisterTunnelError{
		cause:     err,
		permanent: true,
	}
}
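listenReconnect works because ReconnectSignal doubles as an error value: returning &reconnect through the errgroup unwinds every goroutine and surfaces the signal to the type switch above. A minimal sketch of that "signal as error type" pattern, with an illustrative stand-in type rather than the real ReconnectSignal:

package main

import (
	"errors"
	"fmt"
	"time"
)

// reconnectSignal is a hypothetical stand-in: a value that travels
// through error returns and carries a delay for the caller to honour.
type reconnectSignal struct {
	Delay time.Duration
}

func (r *reconnectSignal) Error() string {
	return fmt.Sprintf("reconnect in %s", r.Delay)
}

func serve() error {
	// A worker surfaces the signal as its error.
	return &reconnectSignal{Delay: 2 * time.Second}
}

func main() {
	err := serve()
	var sig *reconnectSignal
	if errors.As(err, &sig) {
		fmt.Println("got signal, retrying after:", sig.Delay)
	}
}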
func UnregisterConnection(
	ctx context.Context,
	muxer *h2mux.Muxer,
	config *TunnelConfig,
) error {
	config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
	rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
	if err != nil {
		// RPC stream open error
		return err
	}
	defer rpcClient.Close()

	return rpcClient.UnregisterConnection(ctx)
type connectedFuse struct {
	fuse    *h2mux.BooleanFuse
	backoff *BackoffHandler
}
func RegisterTunnel(
	ctx context.Context,
	credentialManager *reconnectCredentialManager,
	muxer *h2mux.Muxer,
	config *TunnelConfig,
	logger logger.Service,
	connectionID uint8,
	originLocalIP string,
	uuid uuid.UUID,
) error {
	config.TransportLogger.Debug("initiating RPC stream to register")
	if config.TunnelEventChan != nil {
		config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
	}

	rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
	if err != nil {
		return err
	}
	defer rpcClient.Close()
	// Request server info without blocking tunnel registration; must use capnp library directly.
	serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
		return nil
	})
	LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
	registration := rpcClient.RegisterTunnel(
		ctx,
		config.OriginCert,
		config.Hostname,
		config.RegistrationOptions(connectionID, originLocalIP, uuid),
	)
	if registrationErr := registration.DeserializeError(); registrationErr != nil {
		// RegisterTunnel RPC failure
		return processRegisterTunnelError(registrationErr, config.Metrics, register)
	}

	// Send free tunnel URL to UI
	if config.TunnelEventChan != nil {
		config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.SetUrl, Url: registration.Url}
	}
	credentialManager.SetEventDigest(connectionID, registration.EventDigest)
	return processRegistrationSuccess(config, logger, connectionID, registration, register, credentialManager)
func (cf *connectedFuse) Connected() {
	cf.fuse.Fuse(true)
	cf.backoff.SetGracePeriod()
}
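Connected flips the fuse exactly once and widens the retry grace period. h2mux.BooleanFuse's internals aren't shown in this diff; a plausible one-way latch with the same observable behaviour can be built on sync/atomic (a sketch, not the actual implementation):

package main

import (
	"fmt"
	"sync/atomic"
)

// booleanFuse is a one-way latch: once set to true it never goes back,
// and Value can be read concurrently without extra locking.
type booleanFuse struct {
	value int32
}

func (f *booleanFuse) Fuse(result bool) {
	if result {
		atomic.StoreInt32(&f.value, 1)
	}
}

func (f *booleanFuse) Value() bool {
	return atomic.LoadInt32(&f.value) == 1
}

func main() {
	var f booleanFuse
	fmt.Println(f.Value()) // false
	f.Fuse(true)
	fmt.Println(f.Value()) // true, permanently
}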
func processRegistrationSuccess(
	config *TunnelConfig,
	logger logger.Service,
	connectionID uint8,
	registration *tunnelpogs.TunnelRegistration,
	name rpcName,
	credentialManager *reconnectCredentialManager,
) error {
	for _, logLine := range registration.LogLines {
		logger.Info(logLine)
	}

	if registration.TunnelID != "" {
		config.Metrics.tunnelsHA.AddTunnelID(connectionID, registration.TunnelID)
		logger.Infof("Each HA connection's tunnel IDs: %v", config.Metrics.tunnelsHA.String())
	}

	// Print out the user's trial zone URL in a nice box (if they requested and got one and UI flag is not set)
	if config.TunnelEventChan == nil {
		if config.IsTrialTunnel() {
			if registrationURL, err := url.Parse(registration.Url); err == nil {
				for _, line := range asciiBox(trialZoneMsg(registrationURL.String()), 2) {
					logger.Info(line)
				}
			} else {
				logger.Error("Failed to connect tunnel, please try again.")
				return fmt.Errorf("empty URL in response from Cloudflare edge")
			}
		}
	}

	credentialManager.SetConnDigest(connectionID, registration.ConnDigest)
	config.Metrics.userHostnamesCounts.WithLabelValues(registration.Url).Inc()

	logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
	config.Metrics.regSuccess.WithLabelValues(string(name)).Inc()
	return nil
}
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
	if err.Error() == DuplicateConnectionError {
		metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
		return errDuplicationConnection
	}
	metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
	return serverRegisterTunnelError{
		cause:     err,
		permanent: err.IsPermanent(),
	}
}
func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error {
	config.TransportLogger.Debug("initiating RPC stream to unregister")
	ctx := context.Background()
	rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
	if err != nil {
		// RPC stream open error
		return err
	}
	defer rpcClient.Close()

	// gracePeriod is encoded in int64 using capnproto
	return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
}
func LogServerInfo(
	promise tunnelrpc.ServerInfo_Promise,
	connectionID uint8,
	metrics *TunnelMetrics,
	logger logger.Service,
	tunnelEventChan chan<- ui.TunnelEvent,
) {
	serverInfoMessage, err := promise.Struct()
	if err != nil {
		logger.Errorf("Failed to retrieve server information: %s", err)
		return
	}
	serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
	if err != nil {
		logger.Errorf("Failed to retrieve server information: %s", err)
		return
	}
	// If launch-ui flag is set, send connect msg
	if tunnelEventChan != nil {
		tunnelEventChan <- ui.TunnelEvent{Index: connectionID, EventType: ui.Connected, Location: serverInfo.LocationName}
	}
	logger.Infof("Connected to %s", serverInfo.LocationName)
	metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
}
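Both LogServerInfo and the registration path guard every send with an `if tunnelEventChan != nil` check: a send on a nil channel blocks forever, so the optional UI channel must be checked before use. A tiny demonstration of why the guard matters (types here are illustrative):

package main

import "fmt"

type tunnelEvent struct {
	Location string
}

// notify skips the send when the launch-ui flag never wired a channel up;
// sending on a nil channel would block the goroutine forever.
func notify(ch chan<- tunnelEvent, ev tunnelEvent) {
	if ch != nil {
		ch <- ev
	}
}

func main() {
	notify(nil, tunnelEvent{Location: "SJC"}) // safe no-op
	ch := make(chan tunnelEvent, 1)
	notify(ch, tunnelEvent{Location: "SJC"})
	fmt.Println((<-ch).Location)
}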
func serveWebsocket(wsResp WebsocketResp, req *http.Request, rule *ingress.Rule) (*http.Response, error) {
	if hostHeader := rule.Config.HTTPHostHeader; hostHeader != "" {
		req.Header.Set("Host", hostHeader)
		req.Host = hostHeader
	}

	dialler, ok := rule.Service.(websocket.Dialler)
	if !ok {
		return nil, fmt.Errorf("Websockets aren't supported by the origin service '%s'", rule.Service)
	}
	conn, response, err := websocket.ClientConnect(req, dialler)
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	err = wsResp.WriteRespHeaders(response)
	if err != nil {
		return nil, errors.Wrap(err, "Error writing response header")
	}
	// Copy to/from the stream and the underlying connection. Use the underlying
	// connection because cloudflared doesn't operate on the messages themselves.
	websocket.Stream(conn.UnderlyingConn(), wsResp)

	return response, nil
}
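websocket.Stream bridges the upgraded origin connection and the eyeball stream. The general shape of such a bidirectional pipe, sketched with only io.Copy over two connections (an illustration of the idea, not cloudflared's websocket package):

package main

import (
	"fmt"
	"io"
	"net"
)

// stream copies bytes in both directions until one side closes,
// mirroring the role websocket.Stream plays above.
func stream(a, b net.Conn) {
	done := make(chan struct{}, 2)
	go func() { _, _ = io.Copy(a, b); done <- struct{}{} }()
	go func() { _, _ = io.Copy(b, a); done <- struct{}{} }()
	<-done // first EOF or error ends the session
}

func main() {
	clientSide, clientPeer := net.Pipe() // eyeball end / our end
	originSide, originPeer := net.Pipe() // origin end / our end
	go stream(clientPeer, originPeer)    // bridge the two ends we hold

	go func() { _, _ = clientSide.Write([]byte("ping")) }()
	buf := make([]byte, 4)
	_, _ = originSide.Read(buf)
	fmt.Printf("origin received %q\n", buf)
}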
func uint8ToString(input uint8) string {
	return strconv.FormatUint(uint64(input), 10)
}
// Print out the given lines in a nice ASCII box.
func asciiBox(lines []string, padding int) (box []string) {
	maxLen := maxLen(lines)
	spacer := strings.Repeat(" ", padding)

	border := "+" + strings.Repeat("-", maxLen+(padding*2)) + "+"

	box = append(box, border)
	for _, line := range lines {
		box = append(box, "|"+spacer+line+strings.Repeat(" ", maxLen-len(line))+spacer+"|")
	}
	box = append(box, border)
	return
}

func maxLen(lines []string) int {
	max := 0
	for _, line := range lines {
		if len(line) > max {
			max = len(line)
		}
	}
	return max
}
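asciiBox pads every line to the longest one and frames the result. A hypothetical driver in the same package (so asciiBox and maxLen above are in scope; the URL is a placeholder) showing the framed output:

// exampleAsciiBox is an illustrative helper, not part of the package.
func exampleAsciiBox() {
	for _, line := range asciiBox([]string{
		"Your free tunnel has started! Visit it:",
		"  https://example.trycloudflare.com",
	}, 2) {
		fmt.Println(line)
	}
	// Prints:
	// +-------------------------------------------+
	// |  Your free tunnel has started! Visit it:  |
	// |    https://example.trycloudflare.com      |
	// +-------------------------------------------+
}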
func trialZoneMsg(url string) []string {
	return []string{
		"Your free tunnel has started! Visit it:",
		"  " + url,
	}
func (cf *connectedFuse) IsConnected() bool {
	return cf.fuse.Value()
}
func activeIncidentsMsg(incidents []Incident) string {
@@ -741,26 +377,3 @@ func activeIncidentsMsg(incidents []Incident) string {
	return preamble + " " + strings.Join(incidentStrings, "; ")
}

func findCfRayHeader(h1 *http.Request) string {
	return h1.Header.Get("Cf-Ray")
}

func isLBProbeRequest(req *http.Request) bool {
	return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
}
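Both helpers are thin request inspectors. A quick illustration of what they match; the probe prefix value below is a made-up placeholder, since the real lbProbeUserAgentPrefix constant is defined elsewhere in the package:

package main

import (
	"fmt"
	"net/http"
	"strings"
)

// lbProbePrefix is an illustrative stand-in for lbProbeUserAgentPrefix.
const lbProbePrefix = "ExampleLBProbe"

func main() {
	req, _ := http.NewRequest(http.MethodGet, "https://origin.example.com/", nil)
	req.Header.Set("Cf-Ray", "5d1a2b3c4d5e6f70-SJC")
	req.Header.Set("User-Agent", lbProbePrefix+"/1.0")

	fmt.Println("ray id:", req.Header.Get("Cf-Ray"))
	fmt.Println("is LB probe:", strings.HasPrefix(req.UserAgent(), lbProbePrefix))
}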
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
	openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
	defer openStreamCancel()
	stream, err := muxer.OpenRPCStream(openStreamCtx)
	if err != nil {
		return tunnelpogs.TunnelServer_PogsClient{}, err
	}
	rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
	if err != nil {
		// RPC stream open error
		return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
	}
	return rpcClient, nil
}
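newTunnelRPCClient bounds only the stream-open step with its own deadline, then hands the longer-lived ctx to the RPC client. The context pattern in isolation (the timeout value and openStream function are placeholders):

package main

import (
	"context"
	"fmt"
	"time"
)

// openStream simulates a stream open that honours cancellation.
func openStream(ctx context.Context) error {
	select {
	case <-time.After(50 * time.Millisecond):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	const openStreamTimeout = 30 * time.Second // placeholder value
	// The deadline applies only to opening the stream; later RPC calls
	// keep using the parent ctx, which outlives this timeout.
	openStreamCtx, cancel := context.WithTimeout(context.Background(), openStreamTimeout)
	defer cancel()
	fmt.Println("open stream:", openStream(openStreamCtx))
}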