TUN-528: Move cloudflared into a separate repo
origin/backoffhandler.go (new file, 95 lines)
@@ -0,0 +1,95 @@
package origin

import (
	"time"

	"golang.org/x/net/context"
)

// Redeclare time functions so they can be overridden in tests.
var (
	timeNow   = time.Now
	timeAfter = time.After
)

// BackoffHandler manages exponential backoff and limits the maximum number of retries.
// The base time period is 1 second, doubling with each retry.
// After initial success, a grace period can be set to reset the backoff timer if
// a connection is maintained successfully for a long enough period. The base grace period
// is 2 seconds, doubling with each retry.
type BackoffHandler struct {
	// MaxRetries sets the maximum number of retries to perform. The default value
	// of 0 disables retry completely.
	MaxRetries uint
	// RetryForever caps the exponential backoff period according to MaxRetries
	// but allows you to retry indefinitely.
	RetryForever bool
	// BaseTime sets the initial backoff period.
	BaseTime time.Duration

	retries       uint
	resetDeadline time.Time
}

func (b BackoffHandler) GetBackoffDuration(ctx context.Context) (time.Duration, bool) {
	// Follows the same logic as Backoff, but without mutating the receiver.
	// This select has to happen first to reflect the actual behaviour of the Backoff function.
	select {
	case <-ctx.Done():
		return time.Duration(0), false
	default:
	}
	if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
		// b.retries would be set to 0 at this point
		return time.Second, true
	}
	if b.retries >= b.MaxRetries && !b.RetryForever {
		return time.Duration(0), false
	}
	return time.Duration(b.GetBaseTime() * 1 << b.retries), true
}

// BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires.
// Returns nil if the maximum number of retries has been used.
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
	if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
		b.retries = 0
		b.resetDeadline = time.Time{}
	}
	if b.retries >= b.MaxRetries {
		if !b.RetryForever {
			return nil
		}
	} else {
		b.retries++
	}
	return timeAfter(time.Duration(b.GetBaseTime() * 1 << (b.retries - 1)))
}

// Backoff is used to wait according to exponential backoff. Returns false if the
// maximum number of retries has been used or if the underlying context has been cancelled.
func (b *BackoffHandler) Backoff(ctx context.Context) bool {
	c := b.BackoffTimer()
	if c == nil {
		return false
	}
	select {
	case <-c:
		return true
	case <-ctx.Done():
		return false
	}
}

// SetGracePeriod sets a grace period within which the backoff timer is maintained. After the grace
// period expires, the number of retries and the backoff duration are reset.
func (b *BackoffHandler) SetGracePeriod() {
	b.resetDeadline = timeNow().Add(time.Duration(b.GetBaseTime() * 2 << b.retries))
}

func (b BackoffHandler) GetBaseTime() time.Duration {
	if b.BaseTime == 0 {
		return time.Second
	}
	return b.BaseTime
}
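For illustration, a minimal sketch of how a caller might drive BackoffHandler in a retry loop; the connectOnce operation and the retry limit are hypothetical and not taken from this commit:

package origin

import (
	"errors"
	"log"

	"golang.org/x/net/context"
)

// connectOnce stands in for any operation that can fail transiently (hypothetical).
func connectOnce() error { return errors.New("transient failure") }

// retryWithBackoff is an illustrative caller, not part of this commit.
func retryWithBackoff(ctx context.Context) {
	// MaxRetries: 5 gives waits of roughly 1s, 2s, 4s, 8s, 16s before giving up.
	backoff := BackoffHandler{MaxRetries: 5}
	for {
		if err := connectOnce(); err == nil {
			// On success, arm the grace period so the retry budget resets
			// once the connection has stayed healthy long enough.
			backoff.SetGracePeriod()
			return
		}
		if !backoff.Backoff(ctx) {
			log.Println("retries exhausted or context cancelled")
			return
		}
	}
}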
origin/backoffhandler_test.go (new file, 148 lines)
@@ -0,0 +1,148 @@
package origin

import (
	"testing"
	"time"

	"golang.org/x/net/context"
)

func immediateTimeAfter(time.Duration) <-chan time.Time {
	c := make(chan time.Time, 1)
	c <- time.Now()
	return c
}

func TestBackoffRetries(t *testing.T) {
	// make backoff return immediately
	timeAfter = immediateTimeAfter
	ctx := context.Background()
	backoff := BackoffHandler{MaxRetries: 3}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff failed immediately")
	}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff failed after 1 retry")
	}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff failed after 2 retries")
	}
	if backoff.Backoff(ctx) {
		t.Fatalf("backoff allowed after 3 (max) retries")
	}
}

func TestBackoffCancel(t *testing.T) {
	// prevent backoff from returning normally
	timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
	ctx, cancelFunc := context.WithCancel(context.Background())
	backoff := BackoffHandler{MaxRetries: 3}
	cancelFunc()
	if backoff.Backoff(ctx) {
		t.Fatalf("backoff allowed after cancel")
	}
	if _, ok := backoff.GetBackoffDuration(ctx); ok {
		t.Fatalf("backoff allowed after cancel")
	}
}

func TestBackoffGracePeriod(t *testing.T) {
	currentTime := time.Now()
	// make timeNow return whatever we like
	timeNow = func() time.Time { return currentTime }
	// make backoff return immediately
	timeAfter = immediateTimeAfter
	ctx := context.Background()
	backoff := BackoffHandler{MaxRetries: 1}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff failed immediately")
	}
	// the next call to Backoff would fail unless it's after the grace period
	backoff.SetGracePeriod()
	// advance time to after the grace period (~4 seconds) and see what happens
	currentTime = currentTime.Add(time.Second * 5)
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff failed after the grace period expired")
	}
	// confirm we ignore grace period after backoff
	if backoff.Backoff(ctx) {
		t.Fatalf("backoff allowed after 1 (max) retry")
	}
}

func TestGetBackoffDurationRetries(t *testing.T) {
	// make backoff return immediately
	timeAfter = immediateTimeAfter
	ctx := context.Background()
	backoff := BackoffHandler{MaxRetries: 3}
	if _, ok := backoff.GetBackoffDuration(ctx); !ok {
		t.Fatalf("backoff failed immediately")
	}
	backoff.Backoff(ctx) // noop
	if _, ok := backoff.GetBackoffDuration(ctx); !ok {
		t.Fatalf("backoff failed after 1 retry")
	}
	backoff.Backoff(ctx) // noop
	if _, ok := backoff.GetBackoffDuration(ctx); !ok {
		t.Fatalf("backoff failed after 2 retries")
	}
	backoff.Backoff(ctx) // noop
	if _, ok := backoff.GetBackoffDuration(ctx); ok {
		t.Fatalf("backoff allowed after 3 (max) retries")
	}
	if backoff.Backoff(ctx) {
		t.Fatalf("backoff allowed after 3 (max) retries")
	}
}

func TestGetBackoffDuration(t *testing.T) {
	// make backoff return immediately
	timeAfter = immediateTimeAfter
	ctx := context.Background()
	backoff := BackoffHandler{MaxRetries: 3}
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second {
		t.Fatalf("backoff didn't return 1 second on first retry")
	}
	backoff.Backoff(ctx) // noop
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 {
		t.Fatalf("backoff didn't return 2 seconds on second retry")
	}
	backoff.Backoff(ctx) // noop
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 {
		t.Fatalf("backoff didn't return 4 seconds on third retry")
	}
	backoff.Backoff(ctx) // noop
	if duration, ok := backoff.GetBackoffDuration(ctx); ok || duration != 0 {
		t.Fatalf("backoff didn't return 0 seconds on fourth retry (exceeding limit)")
	}
}

func TestBackoffRetryForever(t *testing.T) {
	// make backoff return immediately
	timeAfter = immediateTimeAfter
	ctx := context.Background()
	backoff := BackoffHandler{MaxRetries: 3, RetryForever: true}
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second {
		t.Fatalf("backoff didn't return 1 second on first retry")
	}
	backoff.Backoff(ctx) // noop
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*2 {
		t.Fatalf("backoff didn't return 2 seconds on second retry")
	}
	backoff.Backoff(ctx) // noop
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*4 {
		t.Fatalf("backoff didn't return 4 seconds on third retry")
	}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff refused on fourth retry despite RetryForever")
	}
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*8 {
		t.Fatalf("backoff returned %v instead of 8 seconds on fourth retry", duration)
	}
	if !backoff.Backoff(ctx) {
		t.Fatalf("backoff refused on fifth retry despite RetryForever")
	}
	if duration, ok := backoff.GetBackoffDuration(ctx); !ok || duration != time.Second*8 {
		t.Fatalf("backoff returned %v instead of 8 seconds on fifth retry", duration)
	}
}
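These tests mutate the package-level timeNow and timeAfter variables and never restore them, which is fine while every test fakes the clock. If a later test needed the real clock, one way to restore the globals would be a helper like the sketch below (not part of this commit; it assumes a Go version with t.Cleanup, i.e. 1.14 or later):

package origin

import (
	"testing"
	"time"
)

// withFakeClock is a hypothetical helper: it swaps in test doubles for the
// package-level time functions and restores the real ones when the test ends.
func withFakeClock(t *testing.T, now func() time.Time, after func(time.Duration) <-chan time.Time) {
	t.Helper()
	origNow, origAfter := timeNow, timeAfter
	timeNow, timeAfter = now, after
	t.Cleanup(func() {
		timeNow, timeAfter = origNow, origAfter
	})
}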
origin/build_info.go (new file, 19 lines)
@@ -0,0 +1,19 @@
package origin

import (
	"runtime"
)

type BuildInfo struct {
	GoOS      string `json:"go_os"`
	GoVersion string `json:"go_version"`
	GoArch    string `json:"go_arch"`
}

func GetBuildInfo() *BuildInfo {
	return &BuildInfo{
		GoOS:      runtime.GOOS,
		GoVersion: runtime.Version(),
		GoArch:    runtime.GOARCH,
	}
}
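The json struct tags mean the build information serializes directly to JSON; a small sketch (the output shown is illustrative and depends on the toolchain):

package origin

import (
	"encoding/json"
	"fmt"
)

// printBuildInfo is illustrative only. It prints something like
// {"go_os":"linux","go_version":"go1.10","go_arch":"amd64"}.
func printBuildInfo() error {
	b, err := json.Marshal(GetBuildInfo())
	if err != nil {
		return err
	}
	fmt.Println(string(b))
	return nil
}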
origin/discovery.go (new file, 82 lines)
@@ -0,0 +1,82 @@
package origin

import (
	"fmt"
	"net"
)

const (
	// Used to discover HA Warp servers
	srvService = "warp"
	srvProto   = "tcp"
	srvName    = "cloudflarewarp.com"
)

func ResolveEdgeIPs(addresses []string) ([]*net.TCPAddr, error) {
	if len(addresses) > 0 {
		var tcpAddrs []*net.TCPAddr
		for _, address := range addresses {
			// Addresses specified (for testing, usually)
			tcpAddr, err := net.ResolveTCPAddr("tcp", address)
			if err != nil {
				return nil, err
			}
			tcpAddrs = append(tcpAddrs, tcpAddr)
		}
		return tcpAddrs, nil
	}
	// HA service discovery lookup
	_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
	if err != nil {
		return nil, err
	}
	var resolvedIPsPerCNAME [][]*net.TCPAddr
	var lookupErr error
	for _, addr := range addrs {
		ips, err := ResolveSRVToTCP(addr)
		if err != nil || len(ips) == 0 {
			// don't return early, we might be able to resolve other addresses
			lookupErr = err
			continue
		}
		resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
	}
	ips := FlattenServiceIPs(resolvedIPsPerCNAME)
	if lookupErr == nil && len(ips) == 0 {
		return nil, fmt.Errorf("Unknown service discovery error")
	}
	return ips, lookupErr
}

func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
	ips, err := net.LookupIP(srv.Target)
	if err != nil {
		return nil, err
	}
	addrs := make([]*net.TCPAddr, len(ips))
	for i, ip := range ips {
		addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
	}
	return addrs, nil
}

// FlattenServiceIPs transposes and flattens the input slices such that the
// first elements of the n inner slices become the first n elements of the result.
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
	var result []*net.TCPAddr
	for len(ipsByService) > 0 {
		filtered := ipsByService[:0]
		for _, ips := range ipsByService {
			if len(ips) == 0 {
				// sanity check
				continue
			}
			result = append(result, ips[0])
			if len(ips) > 1 {
				filtered = append(filtered, ips[1:])
			}
		}
		ipsByService = filtered
	}
	return result
}
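To make the transposition concrete, a short sketch (the addresses and port are made up) showing how FlattenServiceIPs interleaves per-service results so that connections are spread across services rather than exhausting one service first:

package origin

import (
	"fmt"
	"net"
)

// exampleFlatten is illustrative only: service A has two addresses, service B
// has three, and the flattened order is A1, B1, A2, B2, B3.
func exampleFlatten() {
	a := []*net.TCPAddr{
		{IP: net.ParseIP("198.51.100.1"), Port: 7844},
		{IP: net.ParseIP("198.51.100.2"), Port: 7844},
	}
	b := []*net.TCPAddr{
		{IP: net.ParseIP("203.0.113.1"), Port: 7844},
		{IP: net.ParseIP("203.0.113.2"), Port: 7844},
		{IP: net.ParseIP("203.0.113.3"), Port: 7844},
	}
	for _, addr := range FlattenServiceIPs([][]*net.TCPAddr{a, b}) {
		fmt.Println(addr)
	}
}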
origin/discovery_test.go (new file, 45 lines)
@@ -0,0 +1,45 @@
package origin

import (
	"net"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestFlattenServiceIPs(t *testing.T) {
	result := FlattenServiceIPs([][]*net.TCPAddr{
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 1},
			&net.TCPAddr{Port: 2},
			&net.TCPAddr{Port: 3},
			&net.TCPAddr{Port: 4},
		},
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 10},
			&net.TCPAddr{Port: 12},
			&net.TCPAddr{Port: 13},
		},
		[]*net.TCPAddr{
			&net.TCPAddr{Port: 21},
			&net.TCPAddr{Port: 22},
			&net.TCPAddr{Port: 23},
			&net.TCPAddr{Port: 24},
			&net.TCPAddr{Port: 25},
		},
	})
	assert.EqualValues(t, []*net.TCPAddr{
		&net.TCPAddr{Port: 1},
		&net.TCPAddr{Port: 10},
		&net.TCPAddr{Port: 21},
		&net.TCPAddr{Port: 2},
		&net.TCPAddr{Port: 12},
		&net.TCPAddr{Port: 22},
		&net.TCPAddr{Port: 3},
		&net.TCPAddr{Port: 13},
		&net.TCPAddr{Port: 23},
		&net.TCPAddr{Port: 4},
		&net.TCPAddr{Port: 24},
		&net.TCPAddr{Port: 25},
	}, result)
}
origin/metrics.go (new file, 421 lines)
@@ -0,0 +1,421 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
type muxerMetrics struct {
|
||||
rtt *prometheus.GaugeVec
|
||||
rttMin *prometheus.GaugeVec
|
||||
rttMax *prometheus.GaugeVec
|
||||
receiveWindowAve *prometheus.GaugeVec
|
||||
sendWindowAve *prometheus.GaugeVec
|
||||
receiveWindowMin *prometheus.GaugeVec
|
||||
receiveWindowMax *prometheus.GaugeVec
|
||||
sendWindowMin *prometheus.GaugeVec
|
||||
sendWindowMax *prometheus.GaugeVec
|
||||
inBoundRateCurr *prometheus.GaugeVec
|
||||
inBoundRateMin *prometheus.GaugeVec
|
||||
inBoundRateMax *prometheus.GaugeVec
|
||||
outBoundRateCurr *prometheus.GaugeVec
|
||||
outBoundRateMin *prometheus.GaugeVec
|
||||
outBoundRateMax *prometheus.GaugeVec
|
||||
compBytesBefore *prometheus.GaugeVec
|
||||
compBytesAfter *prometheus.GaugeVec
|
||||
compRateAve *prometheus.GaugeVec
|
||||
}
|
||||
|
||||
type TunnelMetrics struct {
|
||||
haConnections prometheus.Gauge
|
||||
totalRequests prometheus.Counter
|
||||
requestsPerTunnel *prometheus.CounterVec
|
||||
// concurrentRequestsLock is a mutex for concurrentRequests and maxConcurrentRequests
|
||||
concurrentRequestsLock sync.Mutex
|
||||
concurrentRequestsPerTunnel *prometheus.GaugeVec
|
||||
// concurrentRequests records count of concurrent requests for each tunnel
|
||||
concurrentRequests map[string]uint64
|
||||
maxConcurrentRequestsPerTunnel *prometheus.GaugeVec
|
||||
// maxConcurrentRequests records the highest count of concurrent requests for each tunnel
|
||||
maxConcurrentRequests map[string]uint64
|
||||
timerRetries prometheus.Gauge
|
||||
responseByCode *prometheus.CounterVec
|
||||
responseCodePerTunnel *prometheus.CounterVec
|
||||
serverLocations *prometheus.GaugeVec
|
||||
// locationLock is a mutex for oldServerLocations
|
||||
locationLock sync.Mutex
|
||||
// oldServerLocations stores the last server the tunnel was connected to
|
||||
oldServerLocations map[string]string
|
||||
|
||||
muxerMetrics *muxerMetrics
|
||||
}
|
||||
|
||||
func newMuxerMetrics() *muxerMetrics {
|
||||
rtt := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "rtt",
|
||||
Help: "Round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rtt)
|
||||
|
||||
rttMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "rtt_min",
|
||||
Help: "Shortest round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rttMin)
|
||||
|
||||
rttMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "rtt_max",
|
||||
Help: "Longest round-trip time in millisecond",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(rttMax)
|
||||
|
||||
receiveWindowAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "receive_window_ave",
|
||||
Help: "Average receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowAve)
|
||||
|
||||
sendWindowAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "send_window_ave",
|
||||
Help: "Average send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowAve)
|
||||
|
||||
receiveWindowMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "receive_window_min",
|
||||
Help: "Smallest receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowMin)
|
||||
|
||||
receiveWindowMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "receive_window_max",
|
||||
Help: "Largest receive window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(receiveWindowMax)
|
||||
|
||||
sendWindowMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "send_window_min",
|
||||
Help: "Smallest send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowMin)
|
||||
|
||||
sendWindowMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "send_window_max",
|
||||
Help: "Largest send window size in bytes",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(sendWindowMax)
|
||||
|
||||
inBoundRateCurr := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "inbound_bytes_per_sec_curr",
|
||||
Help: "Current inbounding bytes per second, 0 if there is no incoming connection",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateCurr)
|
||||
|
||||
inBoundRateMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "inbound_bytes_per_sec_min",
|
||||
Help: "Minimum non-zero inbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateMin)
|
||||
|
||||
inBoundRateMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "inbound_bytes_per_sec_max",
|
||||
Help: "Maximum inbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(inBoundRateMax)
|
||||
|
||||
outBoundRateCurr := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "outbound_bytes_per_sec_curr",
|
||||
Help: "Current outbounding bytes per second, 0 if there is no outgoing traffic",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateCurr)
|
||||
|
||||
outBoundRateMin := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "outbound_bytes_per_sec_min",
|
||||
Help: "Minimum non-zero outbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateMin)
|
||||
|
||||
outBoundRateMax := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "outbound_bytes_per_sec_max",
|
||||
Help: "Maximum outbounding bytes per second",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(outBoundRateMax)
|
||||
|
||||
compBytesBefore := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "comp_bytes_before",
|
||||
Help: "Bytes sent via cross-stream compression, pre compression",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compBytesBefore)
|
||||
|
||||
compBytesAfter := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "comp_bytes_after",
|
||||
Help: "Bytes sent via cross-stream compression, post compression",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compBytesAfter)
|
||||
|
||||
compRateAve := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "comp_rate_ave",
|
||||
Help: "Average outbound cross-stream compression ratio",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(compRateAve)
|
||||
|
||||
return &muxerMetrics{
|
||||
rtt: rtt,
|
||||
rttMin: rttMin,
|
||||
rttMax: rttMax,
|
||||
receiveWindowAve: receiveWindowAve,
|
||||
sendWindowAve: sendWindowAve,
|
||||
receiveWindowMin: receiveWindowMin,
|
||||
receiveWindowMax: receiveWindowMax,
|
||||
sendWindowMin: sendWindowMin,
|
||||
sendWindowMax: sendWindowMax,
|
||||
inBoundRateCurr: inBoundRateCurr,
|
||||
inBoundRateMin: inBoundRateMin,
|
||||
inBoundRateMax: inBoundRateMax,
|
||||
outBoundRateCurr: outBoundRateCurr,
|
||||
outBoundRateMin: outBoundRateMin,
|
||||
outBoundRateMax: outBoundRateMax,
|
||||
compBytesBefore: compBytesBefore,
|
||||
compBytesAfter: compBytesAfter,
|
||||
compRateAve: compRateAve,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *muxerMetrics) update(connectionID string, metrics *h2mux.MuxerMetrics) {
|
||||
m.rtt.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTT))
|
||||
m.rttMin.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMin))
|
||||
m.rttMax.WithLabelValues(connectionID).Set(convertRTTMilliSec(metrics.RTTMax))
|
||||
m.receiveWindowAve.WithLabelValues(connectionID).Set(metrics.ReceiveWindowAve)
|
||||
m.sendWindowAve.WithLabelValues(connectionID).Set(metrics.SendWindowAve)
|
||||
m.receiveWindowMin.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMin))
|
||||
m.receiveWindowMax.WithLabelValues(connectionID).Set(float64(metrics.ReceiveWindowMax))
|
||||
m.sendWindowMin.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMin))
|
||||
m.sendWindowMax.WithLabelValues(connectionID).Set(float64(metrics.SendWindowMax))
|
||||
m.inBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateCurr))
|
||||
m.inBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMin))
|
||||
m.inBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.InBoundRateMax))
|
||||
m.outBoundRateCurr.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateCurr))
|
||||
m.outBoundRateMin.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMin))
|
||||
m.outBoundRateMax.WithLabelValues(connectionID).Set(float64(metrics.OutBoundRateMax))
|
||||
m.compBytesBefore.WithLabelValues(connectionID).Set(float64(metrics.CompBytesBefore.Value()))
|
||||
m.compBytesAfter.WithLabelValues(connectionID).Set(float64(metrics.CompBytesAfter.Value()))
|
||||
m.compRateAve.WithLabelValues(connectionID).Set(float64(metrics.CompRateAve()))
|
||||
}
|
||||
|
||||
func convertRTTMilliSec(t time.Duration) float64 {
|
||||
return float64(t / time.Millisecond)
|
||||
}
|
||||
|
||||
// Metrics that can be collected without asking the edge
|
||||
func NewTunnelMetrics() *TunnelMetrics {
|
||||
haConnections := prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "ha_connections",
|
||||
Help: "Number of active ha connections",
|
||||
})
|
||||
prometheus.MustRegister(haConnections)
|
||||
|
||||
totalRequests := prometheus.NewCounter(
|
||||
prometheus.CounterOpts{
|
||||
Name: "total_requests",
|
||||
Help: "Amount of requests proxied through all the tunnels",
|
||||
})
|
||||
prometheus.MustRegister(totalRequests)
|
||||
|
||||
requestsPerTunnel := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "requests_per_tunnel",
|
||||
Help: "Amount of requests proxied through each tunnel",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(requestsPerTunnel)
|
||||
|
||||
concurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "concurrent_requests_per_tunnel",
|
||||
Help: "Concurrent requests proxied through each tunnel",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(concurrentRequestsPerTunnel)
|
||||
|
||||
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "max_concurrent_requests_per_tunnel",
|
||||
Help: "Largest number of concurrent requests proxied through each tunnel so far",
|
||||
},
|
||||
[]string{"connection_id"},
|
||||
)
|
||||
prometheus.MustRegister(maxConcurrentRequestsPerTunnel)
|
||||
|
||||
timerRetries := prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "timer_retries",
|
||||
Help: "Unacknowledged heart beats count",
|
||||
})
|
||||
prometheus.MustRegister(timerRetries)
|
||||
|
||||
responseByCode := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "response_by_code",
|
||||
Help: "Count of responses by HTTP status code",
|
||||
},
|
||||
[]string{"status_code"},
|
||||
)
|
||||
prometheus.MustRegister(responseByCode)
|
||||
|
||||
responseCodePerTunnel := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "response_code_per_tunnel",
|
||||
Help: "Count of responses by HTTP status code fore each tunnel",
|
||||
},
|
||||
[]string{"connection_id", "status_code"},
|
||||
)
|
||||
prometheus.MustRegister(responseCodePerTunnel)
|
||||
|
||||
serverLocations := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "server_locations",
|
||||
Help: "Where each tunnel is connected to. 1 means current location, 0 means previous locations.",
|
||||
},
|
||||
[]string{"connection_id", "location"},
|
||||
)
|
||||
prometheus.MustRegister(serverLocations)
|
||||
|
||||
return &TunnelMetrics{
|
||||
haConnections: haConnections,
|
||||
totalRequests: totalRequests,
|
||||
requestsPerTunnel: requestsPerTunnel,
|
||||
concurrentRequestsPerTunnel: concurrentRequestsPerTunnel,
|
||||
concurrentRequests: make(map[string]uint64),
|
||||
maxConcurrentRequestsPerTunnel: maxConcurrentRequestsPerTunnel,
|
||||
maxConcurrentRequests: make(map[string]uint64),
|
||||
timerRetries: timerRetries,
|
||||
responseByCode: responseByCode,
|
||||
responseCodePerTunnel: responseCodePerTunnel,
|
||||
serverLocations: serverLocations,
|
||||
oldServerLocations: make(map[string]string),
|
||||
muxerMetrics: newMuxerMetrics(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) incrementHaConnections() {
|
||||
t.haConnections.Inc()
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) decrementHaConnections() {
|
||||
t.haConnections.Dec()
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) updateMuxerMetrics(connectionID string, metrics *h2mux.MuxerMetrics) {
|
||||
t.muxerMetrics.update(connectionID, metrics)
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) incrementRequests(connectionID string) {
|
||||
t.concurrentRequestsLock.Lock()
|
||||
var concurrentRequests uint64
|
||||
var ok bool
|
||||
if concurrentRequests, ok = t.concurrentRequests[connectionID]; ok {
|
||||
t.concurrentRequests[connectionID] += 1
|
||||
concurrentRequests++
|
||||
} else {
|
||||
t.concurrentRequests[connectionID] = 1
|
||||
concurrentRequests = 1
|
||||
}
|
||||
if maxConcurrentRequests, ok := t.maxConcurrentRequests[connectionID]; (ok && maxConcurrentRequests < concurrentRequests) || !ok {
|
||||
t.maxConcurrentRequests[connectionID] = concurrentRequests
|
||||
t.maxConcurrentRequestsPerTunnel.WithLabelValues(connectionID).Set(float64(concurrentRequests))
|
||||
}
|
||||
t.concurrentRequestsLock.Unlock()
|
||||
|
||||
t.totalRequests.Inc()
|
||||
t.requestsPerTunnel.WithLabelValues(connectionID).Inc()
|
||||
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Inc()
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) decrementConcurrentRequests(connectionID string) {
|
||||
t.concurrentRequestsLock.Lock()
|
||||
if _, ok := t.concurrentRequests[connectionID]; ok {
|
||||
t.concurrentRequests[connectionID] -= 1
|
||||
}
|
||||
t.concurrentRequestsLock.Unlock()
|
||||
|
||||
t.concurrentRequestsPerTunnel.WithLabelValues(connectionID).Dec()
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) incrementResponses(connectionID, code string) {
|
||||
t.responseByCode.WithLabelValues(code).Inc()
|
||||
t.responseCodePerTunnel.WithLabelValues(connectionID, code).Inc()
|
||||
|
||||
}
|
||||
|
||||
func (t *TunnelMetrics) registerServerLocation(connectionID, loc string) {
|
||||
t.locationLock.Lock()
|
||||
defer t.locationLock.Unlock()
|
||||
if oldLoc, ok := t.oldServerLocations[connectionID]; ok && oldLoc == loc {
|
||||
return
|
||||
} else if ok {
|
||||
t.serverLocations.WithLabelValues(connectionID, oldLoc).Dec()
|
||||
}
|
||||
t.serverLocations.WithLabelValues(connectionID, loc).Inc()
|
||||
t.oldServerLocations[connectionID] = loc
|
||||
}
|
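Every collector above is registered with prometheus.MustRegister, i.e. against the default registry, so exposing the metrics only requires serving the default handler somewhere. A minimal sketch (the promhttp handler and the idea of a separate listener are assumptions on my part, not something this commit wires up):

package origin

import (
	"net/http"

	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// serveMetrics is an illustrative helper: it exposes everything registered
// against the default Prometheus registry on /metrics at the given address.
func serveMetrics(addr string) error {
	mux := http.NewServeMux()
	mux.Handle("/metrics", promhttp.Handler())
	return http.ListenAndServe(addr, mux)
}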
origin/metrics_test.go (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// NewTunnelMetrics can only be called once, since it registers collectors with the default Prometheus registry
|
||||
var m = NewTunnelMetrics()
|
||||
|
||||
func TestConcurrentRequestsSingleTunnel(t *testing.T) {
|
||||
routines := 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(routines)
|
||||
for i := 0; i < routines; i++ {
|
||||
go func() {
|
||||
m.incrementRequests("0")
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Len(t, m.concurrentRequests, 1)
|
||||
assert.Equal(t, uint64(routines), m.concurrentRequests["0"])
|
||||
assert.Len(t, m.maxConcurrentRequests, 1)
|
||||
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
||||
|
||||
wg.Add(routines / 2)
|
||||
for i := 0; i < routines/2; i++ {
|
||||
go func() {
|
||||
m.decrementConcurrentRequests("0")
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
assert.Equal(t, uint64(routines-routines/2), m.concurrentRequests["0"])
|
||||
assert.Equal(t, uint64(routines), m.maxConcurrentRequests["0"])
|
||||
}
|
||||
|
||||
func TestConcurrentRequestsMultiTunnel(t *testing.T) {
|
||||
m.concurrentRequests = make(map[string]uint64)
|
||||
m.maxConcurrentRequests = make(map[string]uint64)
|
||||
tunnels := 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
go func(i int) {
|
||||
// if we have j < i, then tunnel 0 won't have a chance to call incrementRequests
|
||||
for j := 0; j < i+1; j++ {
|
||||
id := strconv.Itoa(i)
|
||||
m.incrementRequests(id)
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, m.concurrentRequests, tunnels)
|
||||
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
assert.Equal(t, uint64(i+1), m.concurrentRequests[id])
|
||||
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
||||
}
|
||||
|
||||
wg.Add(tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
go func(i int) {
|
||||
for j := 0; j < i+1; j++ {
|
||||
id := strconv.Itoa(i)
|
||||
m.decrementConcurrentRequests(id)
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, m.concurrentRequests, tunnels)
|
||||
assert.Len(t, m.maxConcurrentRequests, tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
assert.Equal(t, uint64(0), m.concurrentRequests[id])
|
||||
assert.Equal(t, uint64(i+1), m.maxConcurrentRequests[id])
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRegisterServerLocation(t *testing.T) {
|
||||
tunnels := 20
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
go func(i int) {
|
||||
id := strconv.Itoa(i)
|
||||
m.registerServerLocation(id, "LHR")
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
for i := 0; i < tunnels; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
assert.Equal(t, "LHR", m.oldServerLocations[id])
|
||||
}
|
||||
|
||||
wg.Add(tunnels)
|
||||
for i := 0; i < tunnels; i++ {
|
||||
go func(i int) {
|
||||
id := strconv.Itoa(i)
|
||||
m.registerServerLocation(id, "AUS")
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
for i := 0; i < tunnels; i++ {
|
||||
id := strconv.Itoa(i)
|
||||
assert.Equal(t, "AUS", m.oldServerLocations[id])
|
||||
}
|
||||
|
||||
}
|
origin/supervisor.go (new file, 234 lines)
@@ -0,0 +1,234 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
// Waiting time before retrying a failed tunnel connection
|
||||
tunnelRetryDuration = time.Second * 10
|
||||
// SRV record resolution TTL
|
||||
resolveTTL = time.Hour
|
||||
// Interval between registering new tunnels
|
||||
registrationInterval = time.Second
|
||||
)
|
||||
|
||||
type Supervisor struct {
|
||||
config *TunnelConfig
|
||||
edgeIPs []*net.TCPAddr
|
||||
// nextUnusedEdgeIP is the index of the next addr in edgeIPs to try
|
||||
nextUnusedEdgeIP int
|
||||
lastResolve time.Time
|
||||
resolverC chan resolveResult
|
||||
tunnelErrors chan tunnelError
|
||||
tunnelsConnecting map[int]chan struct{}
|
||||
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
||||
// currently-connecting tunnels to finish connecting so we can reset the backoff timer
|
||||
nextConnectedIndex int
|
||||
nextConnectedSignal chan struct{}
|
||||
}
|
||||
|
||||
type resolveResult struct {
|
||||
edgeIPs []*net.TCPAddr
|
||||
err error
|
||||
}
|
||||
|
||||
type tunnelError struct {
|
||||
index int
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSupervisor(config *TunnelConfig) *Supervisor {
|
||||
return &Supervisor{
|
||||
config: config,
|
||||
tunnelErrors: make(chan tunnelError),
|
||||
tunnelsConnecting: map[int]chan struct{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal chan struct{}) error {
|
||||
logger := s.config.Logger
|
||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
var tunnelsWaiting []int
|
||||
backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
|
||||
var backoffTimer <-chan time.Time
|
||||
tunnelsActive := s.config.HAConnections
|
||||
|
||||
for {
|
||||
select {
|
||||
// Context cancelled
|
||||
case <-ctx.Done():
|
||||
for tunnelsActive > 0 {
|
||||
<-s.tunnelErrors
|
||||
tunnelsActive--
|
||||
}
|
||||
return nil
|
||||
// startTunnel returned with error
|
||||
// (note that this may also be caused by context cancellation)
|
||||
case tunnelError := <-s.tunnelErrors:
|
||||
tunnelsActive--
|
||||
if tunnelError.err != nil {
|
||||
logger.WithError(tunnelError.err).Warn("Tunnel disconnected due to error")
|
||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||
s.waitForNextTunnel(tunnelError.index)
|
||||
|
||||
if backoffTimer == nil {
|
||||
backoffTimer = backoff.BackoffTimer()
|
||||
}
|
||||
|
||||
// If the error is a dial error, the problem is likely to be network related
|
||||
// try another addr before refreshing since we are likely to get back the
|
||||
// same IPs in the same order. Same problem with duplicate connection error.
|
||||
if s.unusedIPs() {
|
||||
s.replaceEdgeIP(tunnelError.index)
|
||||
} else {
|
||||
s.refreshEdgeIPs()
|
||||
}
|
||||
}
|
||||
// Backoff was set and its timer expired
|
||||
case <-backoffTimer:
|
||||
backoffTimer = nil
|
||||
for _, index := range tunnelsWaiting {
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||
}
|
||||
tunnelsActive += len(tunnelsWaiting)
|
||||
tunnelsWaiting = nil
|
||||
// Tunnel successfully connected
|
||||
case <-s.nextConnectedSignal:
|
||||
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
||||
// No more tunnels outstanding, clear backoff timer
|
||||
backoff.SetGracePeriod()
|
||||
}
|
||||
// DNS resolution returned
|
||||
case result := <-s.resolverC:
|
||||
s.lastResolve = time.Now()
|
||||
s.resolverC = nil
|
||||
if result.err == nil {
|
||||
logger.Debug("Service discovery refresh complete")
|
||||
s.edgeIPs = result.edgeIPs
|
||||
} else {
|
||||
logger.WithError(result.err).Error("Service discovery error")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal chan struct{}) error {
|
||||
logger := s.config.Logger
|
||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||
if err != nil {
|
||||
logger.Infof("ResolveEdgeIPs err")
|
||||
return err
|
||||
}
|
||||
s.edgeIPs = edgeIPs
|
||||
if s.config.HAConnections > len(edgeIPs) {
|
||||
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(edgeIPs))
|
||||
s.config.HAConnections = len(edgeIPs)
|
||||
}
|
||||
s.lastResolve = time.Now()
|
||||
// check for entitlement and version-too-old errors before attempting to register more tunnels
|
||||
s.nextUnusedEdgeIP = s.config.HAConnections
|
||||
go s.startFirstTunnel(ctx, connectedSignal)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
<-s.tunnelErrors
|
||||
// Error can't be nil. A nil error signals that initialization succeeded
|
||||
return fmt.Errorf("context was canceled")
|
||||
case tunnelError := <-s.tunnelErrors:
|
||||
return tunnelError.err
|
||||
case <-connectedSignal:
|
||||
}
|
||||
// At least one successful connection, so start the rest
|
||||
for i := 1; i < s.config.HAConnections; i++ {
|
||||
go s.startTunnel(ctx, i, make(chan struct{}))
|
||||
time.Sleep(registrationInterval)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startFirstTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeeds
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal chan struct{}) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: 0, err: err}
|
||||
}()
|
||||
|
||||
for s.unusedIPs() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
return
|
||||
// try the next address if it was a dialError (network problem) or
|
||||
// dupConnRegisterTunnelError
|
||||
case dialError, dupConnRegisterTunnelError:
|
||||
s.replaceEdgeIP(0)
|
||||
default:
|
||||
return
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s.config, s.getEdgeIP(0), 0, connectedSignal)
|
||||
}
|
||||
}
|
||||
|
||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors.
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal chan struct{}) {
|
||||
err := ServeTunnelLoop(ctx, s.config, s.getEdgeIP(index), uint8(index), connectedSignal)
|
||||
s.tunnelErrors <- tunnelError{index: index, err: err}
|
||||
}
|
||||
|
||||
func (s *Supervisor) newConnectedTunnelSignal(index int) chan struct{} {
|
||||
signal := make(chan struct{})
|
||||
s.tunnelsConnecting[index] = signal
|
||||
s.nextConnectedSignal = signal
|
||||
s.nextConnectedIndex = index
|
||||
return signal
|
||||
}
|
||||
|
||||
func (s *Supervisor) waitForNextTunnel(index int) bool {
|
||||
delete(s.tunnelsConnecting, index)
|
||||
s.nextConnectedSignal = nil
|
||||
for k, v := range s.tunnelsConnecting {
|
||||
s.nextConnectedIndex = k
|
||||
s.nextConnectedSignal = v
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Supervisor) getEdgeIP(index int) *net.TCPAddr {
|
||||
return s.edgeIPs[index%len(s.edgeIPs)]
|
||||
}
|
||||
|
||||
func (s *Supervisor) refreshEdgeIPs() {
|
||||
if s.resolverC != nil {
|
||||
return
|
||||
}
|
||||
if time.Since(s.lastResolve) < resolveTTL {
|
||||
return
|
||||
}
|
||||
s.resolverC = make(chan resolveResult)
|
||||
go func() {
|
||||
edgeIPs, err := ResolveEdgeIPs(s.config.EdgeAddrs)
|
||||
s.resolverC <- resolveResult{edgeIPs: edgeIPs, err: err}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Supervisor) unusedIPs() bool {
|
||||
return s.nextUnusedEdgeIP < len(s.edgeIPs)
|
||||
}
|
||||
|
||||
func (s *Supervisor) replaceEdgeIP(badIPIndex int) {
|
||||
s.edgeIPs[badIPIndex] = s.edgeIPs[s.nextUnusedEdgeIP]
|
||||
s.nextUnusedEdgeIP++
|
||||
}
|
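A sketch of how the supervisor is driven; the TunnelConfig value is assumed to be built elsewhere, and StartTunnelDaemon (in tunnel.go below) does essentially this when more than one HA connection is requested:

package origin

import (
	"golang.org/x/net/context"
)

// runSupervisor is an illustrative wrapper, not part of this commit.
// connectedSignal is closed once the first tunnel registers successfully.
func runSupervisor(ctx context.Context, config *TunnelConfig) error {
	connectedSignal := make(chan struct{})
	go func() {
		<-connectedSignal
		config.Logger.Info("first tunnel connected")
	}()
	return NewSupervisor(config).Run(ctx, connectedSignal)
}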
origin/tunnel.go (new file, 629 lines)
@@ -0,0 +1,629 @@
|
||||
package origin
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/cloudflare/cloudflared/validation"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
|
||||
raven "github.com/getsentry/raven-go"
|
||||
"github.com/pkg/errors"
|
||||
_ "github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
rpc "zombiezen.com/go/capnproto2/rpc"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 15 * time.Second
|
||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||
TagHeaderNamePrefix = "Cf-Warp-Tag-"
|
||||
DuplicateConnectionError = "EDUPCONN"
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
EdgeAddrs []string
|
||||
OriginUrl string
|
||||
Hostname string
|
||||
OriginCert []byte
|
||||
TlsConfig *tls.Config
|
||||
ClientTlsConfig *tls.Config
|
||||
Retries uint
|
||||
HeartbeatInterval time.Duration
|
||||
MaxHeartbeats uint64
|
||||
ClientID string
|
||||
BuildInfo *BuildInfo
|
||||
ReportedVersion string
|
||||
LBPool string
|
||||
Tags []tunnelpogs.Tag
|
||||
HAConnections int
|
||||
HTTPTransport http.RoundTripper
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
ProtocolLogger *log.Logger
|
||||
Logger *log.Logger
|
||||
IsAutoupdated bool
|
||||
GracePeriod time.Duration
|
||||
RunFromTerminal bool
|
||||
NoChunkedEncoding bool
|
||||
WSGI bool
|
||||
CompressionQuality uint64
|
||||
}
|
||||
|
||||
type dialError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e dialError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
type dupConnRegisterTunnelError struct{}
|
||||
|
||||
func (e dupConnRegisterTunnelError) Error() string {
|
||||
return "already connected to this server"
|
||||
}
|
||||
|
||||
type muxerShutdownError struct{}
|
||||
|
||||
func (e muxerShutdownError) Error() string {
|
||||
return "muxer shutdown"
|
||||
}
|
||||
|
||||
// RegisterTunnel error from server
|
||||
type serverRegisterTunnelError struct {
|
||||
cause error
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (e serverRegisterTunnelError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
// RegisterTunnel error from client
|
||||
type clientRegisterTunnelError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e clientRegisterTunnelError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP string) *tunnelpogs.RegistrationOptions {
|
||||
policy := tunnelrpc.ExistingTunnelPolicy_balance
|
||||
if c.HAConnections <= 1 && c.LBPool == "" {
|
||||
policy = tunnelrpc.ExistingTunnelPolicy_disconnect
|
||||
}
|
||||
return &tunnelpogs.RegistrationOptions{
|
||||
ClientID: c.ClientID,
|
||||
Version: c.ReportedVersion,
|
||||
OS: fmt.Sprintf("%s_%s", c.BuildInfo.GoOS, c.BuildInfo.GoArch),
|
||||
ExistingTunnelPolicy: policy,
|
||||
PoolName: c.LBPool,
|
||||
Tags: c.Tags,
|
||||
ConnectionID: connectionID,
|
||||
OriginLocalIP: OriginLocalIP,
|
||||
IsAutoupdated: c.IsAutoupdated,
|
||||
RunFromTerminal: c.RunFromTerminal,
|
||||
CompressionQuality: c.CompressionQuality,
|
||||
}
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(config *TunnelConfig, shutdownC <-chan struct{}, connectedSignal chan struct{}) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-shutdownC
|
||||
cancel()
|
||||
}()
|
||||
// If a user specified negative HAConnections, we will treat it as requesting 1 connection
|
||||
if config.HAConnections > 1 {
|
||||
return NewSupervisor(config).Run(ctx, connectedSignal)
|
||||
} else {
|
||||
addrs, err := ResolveEdgeIPs(config.EdgeAddrs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ServeTunnelLoop(ctx, config, addrs[0], 0, connectedSignal)
|
||||
}
|
||||
}
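A hedged sketch of how a caller might wire StartTunnelDaemon to OS signals; the TunnelConfig value is assumed to come from the caller, and the signal handling shown is only one plausible arrangement rather than what the cloudflared main package actually does:

package origin

import (
	"os"
	"os/signal"
	"syscall"
)

// runDaemon is an illustrative wrapper, not part of this commit.
func runDaemon(config *TunnelConfig) error {
	shutdownC := make(chan struct{})
	connectedSignal := make(chan struct{})

	// Close shutdownC on SIGINT/SIGTERM so the daemon's context is cancelled.
	sigC := make(chan os.Signal, 1)
	signal.Notify(sigC, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-sigC
		close(shutdownC)
	}()

	return StartTunnelDaemon(config, shutdownC, connectedSignal)
}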
|
||||
|
||||
func ServeTunnelLoop(ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
addr *net.TCPAddr,
|
||||
connectionID uint8,
|
||||
connectedSignal chan struct{},
|
||||
) error {
|
||||
logger := config.Logger
|
||||
config.Metrics.incrementHaConnections()
|
||||
defer config.Metrics.decrementHaConnections()
|
||||
backoff := BackoffHandler{MaxRetries: config.Retries}
|
||||
// Used to close connectedSignal no more than once
|
||||
connectedFuse := h2mux.NewBooleanFuse()
|
||||
go func() {
|
||||
if connectedFuse.Await() {
|
||||
close(connectedSignal)
|
||||
}
|
||||
}()
|
||||
// Ensure the above goroutine will terminate if we return without connecting
|
||||
defer connectedFuse.Fuse(false)
|
||||
for {
|
||||
err, recoverable := ServeTunnel(ctx, config, addr, connectionID, connectedFuse, &backoff)
|
||||
if recoverable {
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
logger.Infof("Retrying in %s seconds", duration)
|
||||
backoff.Backoff(ctx)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func ServeTunnel(
|
||||
ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
addr *net.TCPAddr,
|
||||
connectionID uint8,
|
||||
connectedFuse *h2mux.BooleanFuse,
|
||||
backoff *BackoffHandler,
|
||||
) (err error, recoverable bool) {
|
||||
// Treat panics as recoverable errors
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var ok bool
|
||||
err, ok = r.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("ServeTunnel: %v", r)
|
||||
}
|
||||
recoverable = true
|
||||
}
|
||||
}()
|
||||
|
||||
connectionTag := uint8ToString(connectionID)
|
||||
logger := config.Logger.WithField("connectionID", connectionTag)
|
||||
|
||||
// Additional tags to send, other than the hostname, which is set in the cloudflared main package
|
||||
tags := make(map[string]string)
|
||||
tags["ha"] = connectionTag
|
||||
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr.String(), connectionID)
|
||||
if err != nil {
|
||||
errLog := config.Logger.WithError(err)
|
||||
switch err.(type) {
|
||||
case dialError:
|
||||
errLog.Error("Unable to dial edge")
|
||||
case h2mux.MuxerHandshakeError:
|
||||
errLog.Error("Handshake failed with edge server")
|
||||
default:
|
||||
errLog.Error("Tunnel creation failure")
|
||||
return err, false
|
||||
}
|
||||
return err, true
|
||||
}
|
||||
|
||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
err := RegisterTunnel(serveCtx, handler.muxer, config, connectionID, originLocalIP)
|
||||
if err == nil {
|
||||
connectedFuse.Fuse(true)
|
||||
backoff.SetGracePeriod()
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
updateMetricsTickC := time.Tick(config.MetricsUpdateFreq)
|
||||
for {
|
||||
select {
|
||||
case <-serveCtx.Done():
|
||||
// UnregisterTunnel blocks until the RPC call returns
|
||||
err := UnregisterTunnel(handler.muxer, config.GracePeriod, config.Logger)
|
||||
handler.muxer.Shutdown()
|
||||
return err
|
||||
case <-updateMetricsTickC:
|
||||
handler.UpdateMetrics(connectionTag)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
// All routines should stop when the muxer finishes serving. When the muxer shuts down
// gracefully it doesn't return an error, so we return muxerShutdownError
// here to notify the other routines to stop.
|
||||
err := handler.muxer.Serve(serveCtx)
|
||||
if err == nil {
|
||||
return muxerShutdownError{}
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
switch castedErr := err.(type) {
|
||||
case dupConnRegisterTunnelError:
|
||||
logger.Info("Already connected to this server, selecting a different one")
|
||||
return err, true
|
||||
case serverRegisterTunnelError:
|
||||
logger.WithError(castedErr.cause).Error("Register tunnel error from server side")
|
||||
// Don't send registration errors returned by the server to Sentry; they are
// logged on the server side.
|
||||
return castedErr.cause, !castedErr.permanent
|
||||
case clientRegisterTunnelError:
|
||||
logger.WithError(castedErr.cause).Error("Register tunnel error on client side")
|
||||
raven.CaptureError(castedErr.cause, tags)
|
||||
return err, true
|
||||
case muxerShutdownError:
|
||||
logger.Infof("Muxer shutdown")
|
||||
return err, true
|
||||
default:
|
||||
logger.WithError(err).Error("Serve tunnel error")
|
||||
raven.CaptureError(err, tags)
|
||||
return err, true
|
||||
}
|
||||
}
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func IsRPCStreamResponse(headers []h2mux.Header) bool {
|
||||
if len(headers) != 1 {
|
||||
return false
|
||||
}
|
||||
if headers[0].Name != ":status" || headers[0].Value != "200" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func RegisterTunnel(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, connectionID uint8, originLocalIP string) error {
|
||||
config.Logger.Debug("initiating RPC stream to register")
|
||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return clientRegisterTunnelError{cause: err}
|
||||
}
|
||||
if !IsRPCStreamResponse(stream.Headers) {
|
||||
// stream response error
|
||||
return clientRegisterTunnelError{cause: fmt.Errorf("unexpected response to RPC stream open")}
|
||||
}
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(config.Logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(config.Logger.WithField("subsystem", "rpc-transport")),
|
||||
)
|
||||
defer conn.Close()
|
||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||
tsClient := tunnelrpc.TunnelServer{Client: ts.Client}
|
||||
serverInfoPromise := tsClient.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
registration, err := ts.RegisterTunnel(
|
||||
ctx,
|
||||
config.OriginCert,
|
||||
config.Hostname,
|
||||
config.RegistrationOptions(connectionID, originLocalIP),
|
||||
)
|
||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, config.Logger)
|
||||
if err != nil {
|
||||
// RegisterTunnel RPC failure
|
||||
return clientRegisterTunnelError{cause: err}
|
||||
}
|
||||
for _, logLine := range registration.LogLines {
|
||||
config.Logger.Info(logLine)
|
||||
}
|
||||
if registration.Err == DuplicateConnectionError {
|
||||
return dupConnRegisterTunnelError{}
|
||||
} else if registration.Err != "" {
|
||||
return serverRegisterTunnelError{
|
||||
cause: fmt.Errorf("Server error: %s", registration.Err),
|
||||
permanent: registration.PermanentFailure,
|
||||
}
|
||||
}
|
||||
|
||||
config.Logger.Info("Tunnel ID: " + registration.TunnelID)
|
||||
config.Logger.Infof("Route propagating, it may take up to 1 minute for your new route to become functional")
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger *log.Logger) error {
|
||||
logger.Debug("initiating RPC stream to unregister")
|
||||
stream, err := muxer.OpenStream([]h2mux.Header{
|
||||
{Name: ":method", Value: "RPC"},
|
||||
{Name: ":scheme", Value: "capnp"},
|
||||
{Name: ":path", Value: "*"},
|
||||
}, nil)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return err
|
||||
}
|
||||
if !IsRPCStreamResponse(stream.Headers) {
|
||||
// stream response error
|
||||
return fmt.Errorf("unexpected response to RPC stream open")
|
||||
}
|
||||
ctx := context.Background()
|
||||
conn := rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(logger.WithField("subsystem", "rpc-unregister"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger.WithField("subsystem", "rpc-transport")),
|
||||
)
|
||||
defer conn.Close()
|
||||
ts := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
// gracePeriod is encoded in int64 using capnproto
|
||||
return ts.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
||||
}
|
||||
|
||||
func LogServerInfo(
|
||||
promise tunnelrpc.ServerInfo_Promise,
|
||||
connectionID uint8,
|
||||
metrics *TunnelMetrics,
|
||||
logger *log.Logger,
|
||||
) {
|
||||
serverInfoMessage, err := promise.Struct()
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||
return
|
||||
}
|
||||
serverInfo, err := tunnelpogs.UnmarshalServerInfo(serverInfoMessage)
|
||||
if err != nil {
|
||||
logger.WithError(err).Warn("Failed to retrieve server information")
|
||||
return
|
||||
}
|
||||
logger.Infof("Connected to %s", serverInfo.LocationName)
|
||||
metrics.registerServerLocation(uint8ToString(connectionID), serverInfo.LocationName)
|
||||
}
|
||||
|
||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||
for _, header := range h2 {
|
||||
switch header.Name {
|
||||
case ":method":
|
||||
h1.Method = header.Value
|
||||
case ":scheme":
|
||||
case ":authority":
|
||||
// Otherwise the host header will be based on the origin URL
|
||||
h1.Host = header.Value
|
||||
case ":path":
|
||||
u, err := url.Parse(header.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unparseable path")
|
||||
}
|
||||
resolved := h1.URL.ResolveReference(u)
|
||||
// prevent escaping base URL
|
||||
if !strings.HasPrefix(resolved.String(), h1.URL.String()) {
|
||||
return fmt.Errorf("invalid path")
|
||||
}
|
||||
h1.URL = resolved
|
||||
default:
|
||||
h1.Header.Add(http.CanonicalHeaderKey(header.Name), header.Value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
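To illustrate the conversion, a small self-contained sketch (all header values are made up) that builds an h1 request from HTTP/2-style pseudo-headers the same way ServeStream does below:

package origin

import (
	"fmt"
	"net/http"

	"github.com/cloudflare/cloudflared/h2mux"
)

// exampleHeaderConversion is illustrative only; the origin URL and header values are invented.
func exampleHeaderConversion() error {
	req, err := http.NewRequest("GET", "http://localhost:8080", nil)
	if err != nil {
		return err
	}
	h2 := []h2mux.Header{
		{Name: ":method", Value: "POST"},
		{Name: ":authority", Value: "example.com"},
		{Name: ":path", Value: "/api/v1/ping"},
		{Name: "cf-ray", Value: "0123456789abcdef-LHR"},
	}
	if err := H2RequestHeadersToH1Request(h2, req); err != nil {
		return err
	}
	// Method, Host, URL and ordinary headers are now populated on the h1 request.
	fmt.Println(req.Method, req.Host, req.URL.String(), req.Header.Get("Cf-Ray"))
	return nil
}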
|
||||
|
||||
func H1ResponseToH2Response(h1 *http.Response) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{{Name: ":status", Value: fmt.Sprintf("%d", h1.StatusCode)}}
|
||||
for headerName, headerValues := range h1.Header {
|
||||
for _, headerValue := range headerValues {
|
||||
h2 = append(h2, h2mux.Header{Name: strings.ToLower(headerName), Value: headerValue})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func FindCfRayHeader(h1 *http.Request) string {
|
||||
return h1.Header.Get("Cf-Ray")
|
||||
}
|
||||
|
||||
type TunnelHandler struct {
|
||||
originUrl string
|
||||
muxer *h2mux.Muxer
|
||||
httpClient http.RoundTripper
|
||||
tlsConfig *tls.Config
|
||||
tags []tunnelpogs.Tag
|
||||
metrics *TunnelMetrics
|
||||
// connectionID is only used by metrics, and prometheus requires labels to be string
|
||||
connectionID string
|
||||
logger *log.Logger
|
||||
noChunkedEncoding bool
|
||||
}
|
||||
|
||||
var dialer = net.Dialer{DualStack: true}
|
||||
|
||||
// NewTunnelHandler returns a TunnelHandler, origin LAN IP and error
|
||||
func NewTunnelHandler(ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
addr string,
|
||||
connectionID uint8,
|
||||
) (*TunnelHandler, string, error) {
|
||||
originURL, err := validation.ValidateUrl(config.OriginUrl)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("Unable to parse origin url %#v", originURL)
|
||||
}
|
||||
h := &TunnelHandler{
|
||||
originUrl: originURL,
|
||||
httpClient: config.HTTPTransport,
|
||||
tlsConfig: config.ClientTlsConfig,
|
||||
tags: config.Tags,
|
||||
metrics: config.Metrics,
|
||||
connectionID: uint8ToString(connectionID),
|
||||
logger: config.Logger,
|
||||
noChunkedEncoding: config.NoChunkedEncoding,
|
||||
}
|
||||
if h.httpClient == nil {
|
||||
h.httpClient = http.DefaultTransport
|
||||
}
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||
// TUN-92: enforce a timeout on dial and handshake (as tls.Dial does not support one)
|
||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", addr)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
return nil, "", dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||
}
|
||||
edgeConn := tls.Client(plaintextEdgeConn, config.TlsConfig)
|
||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||
err = edgeConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, "", dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
edgeConn.SetDeadline(time.Time{})
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
h.muxer, err = h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
Handler: h,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: config.HeartbeatInterval,
|
||||
MaxHeartbeats: config.MaxHeartbeats,
|
||||
Logger: config.ProtocolLogger.WithFields(log.Fields{}),
|
||||
CompressionQuality: h2mux.CompressionSetting(config.CompressionQuality),
|
||||
})
|
||||
if err != nil {
|
||||
return h, "", errors.New("TLS handshake error")
|
||||
}
|
||||
return h, edgeConn.LocalAddr().String(), err
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) AppendTagHeaders(r *http.Request) {
|
||||
for _, tag := range h.tags {
|
||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) ServeStream(stream *h2mux.MuxedStream) error {
|
||||
h.metrics.incrementRequests(h.connectionID)
|
||||
req, err := http.NewRequest("GET", h.originUrl, h2mux.MuxedStreamReader{MuxedStream: stream})
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Panic("Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
h.logger.WithError(err).Error("invalid request received")
|
||||
}
|
||||
h.AppendTagHeaders(req)
|
||||
cfRay := FindCfRayHeader(req)
|
||||
lbProbe := isLBProbeRequest(req)
|
||||
h.logRequest(req, cfRay, lbProbe)
|
||||
if websocket.IsWebSocketUpgrade(req) {
|
||||
conn, response, err := websocket.ClientConnect(req, h.tlsConfig)
|
||||
if err != nil {
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
defer conn.Close()
|
||||
// Copy to/from the stream to the underlying connection. Use the underlying
// connection because cloudflared doesn't operate on the messages themselves.
|
||||
websocket.Stream(conn.UnderlyingConn(), stream)
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
h.logResponse(response, cfRay, lbProbe)
|
||||
}
|
||||
} else {
|
||||
// Support for WSGI Servers by switching transfer encoding from chunked to gzip/deflate
|
||||
if h.noChunkedEncoding {
|
||||
req.TransferEncoding = []string{"gzip", "deflate"}
|
||||
cLength, err := strconv.Atoi(req.Header.Get("Content-Length"))
|
||||
if err == nil {
|
||||
req.ContentLength = int64(cLength)
|
||||
}
|
||||
}
|
||||
|
||||
response, err := h.httpClient.RoundTrip(req)
|
||||
|
||||
if err != nil {
|
||||
h.logError(stream, err)
|
||||
} else {
|
||||
defer response.Body.Close()
|
||||
stream.WriteHeaders(H1ResponseToH2Response(response))
|
||||
if h.isEventStream(response) {
|
||||
h.writeEventStream(stream, response.Body)
|
||||
} else {
|
||||
// Use CopyBuffer, because Copy only allocates a 32KiB buffer, and cross-stream
// compression generates its dictionary on the first write.
|
||||
io.CopyBuffer(stream, response.Body, make([]byte, 512*1024))
|
||||
}
|
||||
|
||||
h.metrics.incrementResponses(h.connectionID, "200")
|
||||
h.logResponse(response, cfRay, lbProbe)
|
||||
}
|
||||
}
|
||||
h.metrics.decrementConcurrentRequests(h.connectionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) writeEventStream(stream *h2mux.MuxedStream, responseBody io.ReadCloser) {
|
||||
reader := bufio.NewReader(responseBody)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
stream.Write(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) isEventStream(response *http.Response) bool {
|
||||
if response.Header.Get("content-type") == "text/event-stream" {
|
||||
h.logger.Debug("Detected Server-Side Events from Origin")
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) logError(stream *h2mux.MuxedStream, err error) {
|
||||
h.logger.WithError(err).Error("HTTP request error")
|
||||
stream.WriteHeaders([]h2mux.Header{{Name: ":status", Value: "502"}})
|
||||
stream.Write([]byte("502 Bad Gateway"))
|
||||
h.metrics.incrementResponses(h.connectionID, "502")
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) logRequest(req *http.Request, cfRay string, lbProbe bool) {
|
||||
if cfRay != "" {
|
||||
h.logger.WithField("CF-RAY", cfRay).Debugf("%s %s %s", req.Method, req.URL, req.Proto)
|
||||
} else if lbProbe {
|
||||
h.logger.Debugf("Load Balancer health check %s %s %s", req.Method, req.URL, req.Proto)
|
||||
} else {
|
||||
h.logger.Warnf("All requests should have a CF-RAY header. Please open a support ticket with Cloudflare. %s %s %s ", req.Method, req.URL, req.Proto)
|
||||
}
|
||||
h.logger.Debugf("Request Headers %+v", req.Header)
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) logResponse(r *http.Response, cfRay string, lbProbe bool) {
|
||||
if cfRay != "" {
|
||||
h.logger.WithField("CF-RAY", cfRay).Debugf("%s", r.Status)
|
||||
} else if lbProbe {
|
||||
h.logger.Debugf("Response to Load Balancer health check %s", r.Status)
|
||||
} else {
|
||||
h.logger.Infof("%s", r.Status)
|
||||
}
|
||||
h.logger.Debugf("Response Headers %+v", r.Header)
|
||||
}
|
||||
|
||||
func (h *TunnelHandler) UpdateMetrics(connectionID string) {
|
||||
h.metrics.updateMuxerMetrics(connectionID, h.muxer.Metrics())
|
||||
}
|
||||
|
||||
func uint8ToString(input uint8) string {
|
||||
return strconv.FormatUint(uint64(input), 10)
|
||||
}
|
||||
|
||||
func isLBProbeRequest(req *http.Request) bool {
|
||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||
}