mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:49:57 +00:00
TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -13,10 +11,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
||||
edgeH2muxTLSServerName = "cftunnel.com"
|
||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||
lbProbeUserAgentPrefix = "Mozilla/5.0 (compatible; Cloudflare-Traffic-Manager/1.0; +https://www.cloudflare.com/traffic-manager/;"
|
||||
)
|
||||
|
||||
@@ -43,57 +37,6 @@ func (c *ClassicTunnelConfig) IsTrialZone() bool {
|
||||
return c.Hostname == ""
|
||||
}
|
||||
|
||||
type Protocol int64
|
||||
|
||||
const (
|
||||
H2mux Protocol = iota
|
||||
HTTP2
|
||||
)
|
||||
|
||||
func SelectProtocol(s string, accountTag string, http2Percentage uint32) (Protocol, bool) {
|
||||
switch s {
|
||||
case "h2mux":
|
||||
return H2mux, true
|
||||
case "http2":
|
||||
return HTTP2, true
|
||||
case "auto":
|
||||
if tryHTTP2(accountTag, http2Percentage) {
|
||||
return HTTP2, true
|
||||
}
|
||||
return H2mux, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func tryHTTP2(accountTag string, http2Percentage uint32) bool {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(accountTag))
|
||||
return h.Sum32()%100 < http2Percentage
|
||||
}
|
||||
|
||||
func (p Protocol) ServerName() string {
|
||||
switch p {
|
||||
case H2mux:
|
||||
return edgeH2muxTLSServerName
|
||||
case HTTP2:
|
||||
return edgeH2TLSServerName
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (p Protocol) String() string {
|
||||
switch p {
|
||||
case H2mux:
|
||||
return "h2mux"
|
||||
case HTTP2:
|
||||
return "http2"
|
||||
default:
|
||||
return fmt.Sprintf("unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
type OriginClient interface {
|
||||
Proxy(w ResponseWriter, req *http.Request, isWebsocket bool) error
|
||||
}
|
||||
|
@@ -37,7 +37,7 @@ type HTTP2Connection struct {
|
||||
connectedFuse ConnectedFuse
|
||||
}
|
||||
|
||||
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) (*HTTP2Connection, error) {
|
||||
func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, namedTunnelConfig *NamedTunnelConfig, connOptions *tunnelpogs.ConnectionOptions, observer *Observer, connIndex uint8, connectedFuse ConnectedFuse) *HTTP2Connection {
|
||||
return &HTTP2Connection{
|
||||
conn: conn,
|
||||
server: &http2.Server{
|
||||
@@ -52,7 +52,7 @@ func NewHTTP2Connection(conn net.Conn, config *Config, originURL *url.URL, named
|
||||
connIndex: connIndex,
|
||||
wg: &sync.WaitGroup{},
|
||||
connectedFuse: connectedFuse,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *HTTP2Connection) Serve(ctx context.Context) {
|
||||
|
@@ -299,7 +299,7 @@ func convertRTTMilliSec(t time.Duration) float64 {
|
||||
}
|
||||
|
||||
// Metrics that can be collected without asking the edge
|
||||
func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
|
||||
func newTunnelMetrics() *tunnelMetrics {
|
||||
maxConcurrentRequestsPerTunnel := prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: MetricsNamespace,
|
||||
@@ -374,16 +374,12 @@ func newTunnelMetrics(protocol Protocol) *tunnelMetrics {
|
||||
[]string{"rpcName"},
|
||||
)
|
||||
prometheus.MustRegister(registerSuccess)
|
||||
var muxerMetrics *muxerMetrics
|
||||
if protocol == H2mux {
|
||||
muxerMetrics = newMuxerMetrics()
|
||||
}
|
||||
|
||||
return &tunnelMetrics{
|
||||
timerRetries: timerRetries,
|
||||
serverLocations: serverLocations,
|
||||
oldServerLocations: make(map[string]string),
|
||||
muxerMetrics: muxerMetrics,
|
||||
muxerMetrics: newMuxerMetrics(),
|
||||
tunnelsHA: NewTunnelsForHA(),
|
||||
regSuccess: registerSuccess,
|
||||
regFail: registerFail,
|
||||
|
@@ -16,10 +16,10 @@ type Observer struct {
|
||||
tunnelEventChan chan<- ui.TunnelEvent
|
||||
}
|
||||
|
||||
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent, protocol Protocol) *Observer {
|
||||
func NewObserver(logger logger.Service, tunnelEventChan chan<- ui.TunnelEvent) *Observer {
|
||||
return &Observer{
|
||||
logger,
|
||||
newTunnelMetrics(protocol),
|
||||
newTunnelMetrics(),
|
||||
tunnelEventChan,
|
||||
}
|
||||
}
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// can only be called once
|
||||
var m = newTunnelMetrics(H2mux)
|
||||
var m = newTunnelMetrics()
|
||||
|
||||
func TestRegisterServerLocation(t *testing.T) {
|
||||
tunnels := 20
|
||||
|
179
connection/protocol.go
Normal file
179
connection/protocol.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
AvailableProtocolFlagMessage = "Available protocols: http2 - Go's implementation, h2mux - Cloudflare's implementation of HTTP/2, and auto - automatically select between http2 and h2mux"
|
||||
// edgeH2muxTLSServerName is the server name to establish h2mux connection with edge
|
||||
edgeH2muxTLSServerName = "cftunnel.com"
|
||||
// edgeH2TLSServerName is the server name to establish http2 connection with edge
|
||||
edgeH2TLSServerName = "h2.cftunnel.com"
|
||||
// threshold to switch back to h2mux when the user intentionally pick --protocol http2
|
||||
explicitHTTP2FallbackThreshold = -1
|
||||
autoSelectFlag = "auto"
|
||||
)
|
||||
|
||||
var (
|
||||
ProtocolList = []Protocol{H2mux, HTTP2}
|
||||
)
|
||||
|
||||
type Protocol int64
|
||||
|
||||
const (
|
||||
H2mux Protocol = iota
|
||||
HTTP2
|
||||
)
|
||||
|
||||
func (p Protocol) ServerName() string {
|
||||
switch p {
|
||||
case H2mux:
|
||||
return edgeH2muxTLSServerName
|
||||
case HTTP2:
|
||||
return edgeH2TLSServerName
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback returns the fallback protocol and whether the protocol has a fallback
|
||||
func (p Protocol) fallback() (Protocol, bool) {
|
||||
switch p {
|
||||
case H2mux:
|
||||
return 0, false
|
||||
case HTTP2:
|
||||
return H2mux, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (p Protocol) String() string {
|
||||
switch p {
|
||||
case H2mux:
|
||||
return "h2mux"
|
||||
case HTTP2:
|
||||
return "http2"
|
||||
default:
|
||||
return fmt.Sprintf("unknown protocol")
|
||||
}
|
||||
}
|
||||
|
||||
type ProtocolSelector interface {
|
||||
Current() Protocol
|
||||
Fallback() (Protocol, bool)
|
||||
}
|
||||
|
||||
type staticProtocolSelector struct {
|
||||
current Protocol
|
||||
}
|
||||
|
||||
func (s *staticProtocolSelector) Current() Protocol {
|
||||
return s.current
|
||||
}
|
||||
|
||||
func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
type autoProtocolSelector struct {
|
||||
lock sync.RWMutex
|
||||
current Protocol
|
||||
switchThrehold int32
|
||||
fetchFunc PercentageFetcher
|
||||
refreshAfter time.Time
|
||||
ttl time.Duration
|
||||
logger logger.Service
|
||||
}
|
||||
|
||||
func newAutoProtocolSelector(
|
||||
current Protocol,
|
||||
switchThrehold int32,
|
||||
fetchFunc PercentageFetcher,
|
||||
ttl time.Duration,
|
||||
logger logger.Service,
|
||||
) *autoProtocolSelector {
|
||||
return &autoProtocolSelector{
|
||||
current: current,
|
||||
switchThrehold: switchThrehold,
|
||||
fetchFunc: fetchFunc,
|
||||
refreshAfter: time.Now().Add(ttl),
|
||||
ttl: ttl,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *autoProtocolSelector) Current() Protocol {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
if time.Now().Before(s.refreshAfter) {
|
||||
return s.current
|
||||
}
|
||||
|
||||
percentage, err := s.fetchFunc()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to refresh protocol, err: %v", err)
|
||||
return s.current
|
||||
}
|
||||
|
||||
if s.switchThrehold < percentage {
|
||||
s.current = HTTP2
|
||||
} else {
|
||||
s.current = H2mux
|
||||
}
|
||||
s.refreshAfter = time.Now().Add(s.ttl)
|
||||
return s.current
|
||||
}
|
||||
|
||||
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
return s.current.fallback()
|
||||
}
|
||||
|
||||
type PercentageFetcher func() (int32, error)
|
||||
|
||||
func NewProtocolSelector(protocolFlag string, namedTunnel *NamedTunnelConfig, fetchFunc PercentageFetcher, ttl time.Duration, logger logger.Service) (ProtocolSelector, error) {
|
||||
if namedTunnel == nil {
|
||||
return &staticProtocolSelector{
|
||||
current: H2mux,
|
||||
}, nil
|
||||
}
|
||||
if protocolFlag == H2mux.String() {
|
||||
return &staticProtocolSelector{
|
||||
current: H2mux,
|
||||
}, nil
|
||||
}
|
||||
|
||||
http2Percentage, err := fetchFunc()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocolFlag == HTTP2.String() {
|
||||
if http2Percentage < 0 {
|
||||
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
|
||||
}
|
||||
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, logger), nil
|
||||
}
|
||||
|
||||
if protocolFlag != autoSelectFlag {
|
||||
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
|
||||
}
|
||||
threshold := switchThreshold(namedTunnel.Auth.AccountTag)
|
||||
if threshold < http2Percentage {
|
||||
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, logger), nil
|
||||
}
|
||||
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, logger), nil
|
||||
}
|
||||
|
||||
func switchThreshold(accountTag string) int32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(accountTag))
|
||||
return int32(h.Sum32() % 100)
|
||||
}
|
220
connection/protocol_test.go
Normal file
220
connection/protocol_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
testNoTTL = 0
|
||||
)
|
||||
|
||||
var (
|
||||
testNamedTunnelConfig = &NamedTunnelConfig{
|
||||
Auth: pogs.TunnelAuth{
|
||||
AccountTag: "testAccountTag",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func mockFetcher(percentage int32) PercentageFetcher {
|
||||
return func() (int32, error) {
|
||||
return percentage, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mockFetcherWithError() PercentageFetcher {
|
||||
return func() (int32, error) {
|
||||
return 0, fmt.Errorf("failed to fetch precentage")
|
||||
}
|
||||
}
|
||||
|
||||
type dynamicMockFetcher struct {
|
||||
percentage int32
|
||||
err error
|
||||
}
|
||||
|
||||
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
|
||||
return func() (int32, error) {
|
||||
if dmf.err != nil {
|
||||
return 0, dmf.err
|
||||
}
|
||||
return dmf.percentage, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProtocolSelector(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
expectedProtocol Protocol
|
||||
hasFallback bool
|
||||
expectedFallback Protocol
|
||||
namedTunnelConfig *NamedTunnelConfig
|
||||
fetchFunc PercentageFetcher
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "classic tunnel",
|
||||
protocol: "h2mux",
|
||||
expectedProtocol: H2mux,
|
||||
namedTunnelConfig: nil,
|
||||
},
|
||||
{
|
||||
name: "named tunnel over h2mux",
|
||||
protocol: "h2mux",
|
||||
expectedProtocol: H2mux,
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
name: "named tunnel over http2",
|
||||
protocol: "http2",
|
||||
expectedProtocol: HTTP2,
|
||||
hasFallback: true,
|
||||
expectedFallback: H2mux,
|
||||
fetchFunc: mockFetcher(0),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
name: "named tunnel http2 disabled",
|
||||
protocol: "http2",
|
||||
expectedProtocol: H2mux,
|
||||
fetchFunc: mockFetcher(-1),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
name: "named tunnel auto all http2 disabled",
|
||||
protocol: "auto",
|
||||
expectedProtocol: H2mux,
|
||||
fetchFunc: mockFetcher(-1),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
name: "named tunnel auto to h2mux",
|
||||
protocol: "auto",
|
||||
expectedProtocol: H2mux,
|
||||
fetchFunc: mockFetcher(0),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
name: "named tunnel auto to http2",
|
||||
protocol: "auto",
|
||||
expectedProtocol: HTTP2,
|
||||
hasFallback: true,
|
||||
expectedFallback: H2mux,
|
||||
fetchFunc: mockFetcher(100),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
},
|
||||
{
|
||||
// None named tunnel can only use h2mux, so specifying an unknown protocol is not an error
|
||||
name: "classic tunnel unknown protocol",
|
||||
protocol: "unknown",
|
||||
expectedProtocol: H2mux,
|
||||
},
|
||||
{
|
||||
name: "named tunnel unknown protocol",
|
||||
protocol: "unknown",
|
||||
fetchFunc: mockFetcher(100),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "named tunnel fetch error",
|
||||
protocol: "unknown",
|
||||
fetchFunc: mockFetcherWithError(),
|
||||
namedTunnelConfig: testNamedTunnelConfig,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
logger, _ := logger.New()
|
||||
for _, test := range tests {
|
||||
selector, err := NewProtocolSelector(test.protocol, test.namedTunnelConfig, test.fetchFunc, testNoTTL, logger)
|
||||
if test.wantErr {
|
||||
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||
} else {
|
||||
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
|
||||
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
|
||||
fallback, ok := selector.Fallback()
|
||||
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
|
||||
if test.hasFallback {
|
||||
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoProtocolSelectorRefresh(t *testing.T) {
|
||||
logger, _ := logger.New()
|
||||
fetcher := dynamicMockFetcher{}
|
||||
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
|
||||
fetcher.percentage = 100
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = 0
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
|
||||
fetcher.percentage = 100
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.err = fmt.Errorf("failed to fetch")
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = -1
|
||||
fetcher.err = nil
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
|
||||
fetcher.percentage = 0
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
|
||||
fetcher.percentage = 100
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
}
|
||||
|
||||
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
|
||||
logger, _ := logger.New()
|
||||
fetcher := dynamicMockFetcher{}
|
||||
selector, err := NewProtocolSelector("http2", testNamedTunnelConfig, fetcher.fetch(), testNoTTL, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = 100
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = 0
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.err = fmt.Errorf("failed to fetch")
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = -1
|
||||
fetcher.err = nil
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
|
||||
fetcher.percentage = 0
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = 100
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = -1
|
||||
assert.Equal(t, H2mux, selector.Current())
|
||||
}
|
||||
|
||||
func TestProtocolSelectorRefreshTTL(t *testing.T) {
|
||||
logger, _ := logger.New()
|
||||
fetcher := dynamicMockFetcher{percentage: 100}
|
||||
selector, err := NewProtocolSelector("auto", testNamedTunnelConfig, fetcher.fetch(), time.Hour, logger)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
|
||||
fetcher.percentage = 0
|
||||
assert.Equal(t, HTTP2, selector.Current())
|
||||
}
|
Reference in New Issue
Block a user