TUN-5249: Revert "TUN-5138: Switch to QUIC on auto protocol based on threshold"

This reverts commit e445fd92f7
This commit is contained in:
Sudarsan Reddy
2021-10-13 19:06:31 +01:00
parent 5148d00516
commit 2822fbe3db
6 changed files with 217 additions and 249 deletions

View File

@@ -7,8 +7,6 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/edgediscovery"
)
const (
@@ -26,7 +24,7 @@ const (
var (
// ProtocolList represents a list of supported protocols for communication with the edge.
ProtocolList = []Protocol{H2mux, HTTP2, HTTP2Warp, QUIC, QUICWarp}
ProtocolList = []Protocol{H2mux, HTTP2, QUIC}
)
type Protocol int64
@@ -38,12 +36,6 @@ const (
HTTP2
// QUIC is used only with named tunnels.
QUIC
// HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
// H2mux on HTTP2 failure to connect.
HTTP2Warp
//QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
// dont' want HTTP2 to fallback to H2mux
QUICWarp
)
// Fallback returns the fallback protocol and whether the protocol has a fallback
@@ -53,12 +45,8 @@ func (p Protocol) fallback() (Protocol, bool) {
return 0, false
case HTTP2:
return H2mux, true
case HTTP2Warp:
return 0, false
case QUIC:
return HTTP2, true
case QUICWarp:
return HTTP2Warp, true
default:
return 0, false
}
@@ -68,9 +56,9 @@ func (p Protocol) String() string {
switch p {
case H2mux:
return "h2mux"
case HTTP2, HTTP2Warp:
case HTTP2:
return "http2"
case QUIC, QUICWarp:
case QUIC:
return "quic"
default:
return fmt.Sprintf("unknown protocol")
@@ -83,11 +71,11 @@ func (p Protocol) TLSSettings() *TLSSettings {
return &TLSSettings{
ServerName: edgeH2muxTLSServerName,
}
case HTTP2, HTTP2Warp:
case HTTP2:
return &TLSSettings{
ServerName: edgeH2TLSServerName,
}
case QUIC, QUICWarp:
case QUIC:
return &TLSSettings{
ServerName: edgeQUICServerName,
NextProtos: []string{"argotunnel"},
@@ -120,36 +108,29 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
}
type autoProtocolSelector struct {
lock sync.RWMutex
current Protocol
// protocolPool is desired protocols in the order of priority they should be picked in.
protocolPool []Protocol
switchThreshold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
lock sync.RWMutex
current Protocol
switchThrehold int32
fetchFunc PercentageFetcher
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
}
func newAutoProtocolSelector(
current Protocol,
protocolPool []Protocol,
switchThreshold int32,
switchThrehold int32,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
) *autoProtocolSelector {
return &autoProtocolSelector{
current: current,
protocolPool: protocolPool,
switchThreshold: switchThreshold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
log: log,
current: current,
switchThrehold: switchThrehold,
fetchFunc: fetchFunc,
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
log: log,
}
}
@@ -160,39 +141,28 @@ func (s *autoProtocolSelector) Current() Protocol {
return s.current
}
protocol, err := getProtocol(s.protocolPool, s.fetchFunc, s.switchThreshold)
percentage, err := s.fetchFunc()
if err != nil {
s.log.Err(err).Msg("Failed to refresh protocol")
return s.current
}
s.current = protocol
if s.switchThrehold < percentage {
s.current = HTTP2
} else {
s.current = H2mux
}
s.refreshAfter = time.Now().Add(s.ttl)
return s.current
}
func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThreshold int32) (Protocol, error) {
protocolPercentages, err := fetchFunc()
if err != nil {
return 0, err
}
for _, protocol := range protocolPool {
protocolPercentage := protocolPercentages.GetPercentage(protocol.String())
if protocolPercentage > switchThreshold {
return protocol, nil
}
}
return protocolPool[len(protocolPool)-1], nil
}
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
return s.current.fallback()
}
type PercentageFetcher func() (edgediscovery.ProtocolPercents, error)
type PercentageFetcher func() (int32, error)
func NewProtocolSelector(
protocolFlag string,
@@ -209,34 +179,22 @@ func NewProtocolSelector(
}, nil
}
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold)
if err != nil {
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
// warp routing cannot be served over h2mux connections
if warpRoutingEnabled {
if protocolFlag == H2mux.String() {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
}
if protocolFlag == QUIC.String() {
return &staticProtocolSelector{
current: QUIC,
}, nil
}
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if warpRoutingEnabled {
if protocolFlag == H2mux.String() || fetchedProtocol == H2mux {
log.Warn().Msg("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it.")
protocolFlag = HTTP2.String()
fetchedProtocol = HTTP2Warp
}
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
}
func selectNamedTunnelProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
if protocolFlag == H2mux.String() {
return &staticProtocolSelector{
current: H2mux,
@@ -244,41 +202,31 @@ func selectNamedTunnelProtocols(
}
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUIC, []Protocol{QUIC, HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(QUIC, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
http2Percentage, err := fetchFunc()
if err != nil {
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command.")
return &staticProtocolSelector{
current: HTTP2,
}, nil
}
if protocolFlag == HTTP2.String() {
return newAutoProtocolSelector(HTTP2, []Protocol{HTTP2, H2mux}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
if http2Percentage < 0 {
return newAutoProtocolSelector(H2mux, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
return newAutoProtocolSelector(HTTP2, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil
}
func selectWarpRoutingProtocols(
protocolFlag string,
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
threshold int32,
protocol Protocol,
) (ProtocolSelector, error) {
if protocolFlag == QUIC.String() {
return newAutoProtocolSelector(QUICWarp, []Protocol{QUICWarp, HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
threshold := switchThreshold(namedTunnel.Credentials.AccountTag)
if threshold < http2Percentage {
return newAutoProtocolSelector(HTTP2, threshold, fetchFunc, ttl, log), nil
}
if protocolFlag == HTTP2.String() {
return newAutoProtocolSelector(HTTP2Warp, []Protocol{HTTP2Warp}, explicitHTTP2FallbackThreshold, fetchFunc, ttl, log), nil
}
if protocolFlag != autoSelectFlag {
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
}
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(H2mux, threshold, fetchFunc, ttl, log), nil
}
func switchThreshold(accountTag string) int32 {

View File

@@ -6,8 +6,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/edgediscovery"
)
const (
@@ -23,23 +21,29 @@ var (
}
)
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) PercentageFetcher {
return func() (edgediscovery.ProtocolPercents, error) {
if getError {
return nil, fmt.Errorf("failed to fetch precentage")
}
return protocolPercent, nil
func mockFetcher(percentage int32) PercentageFetcher {
return func() (int32, error) {
return percentage, nil
}
}
func mockFetcherWithError() PercentageFetcher {
return func() (int32, error) {
return 0, fmt.Errorf("failed to fetch precentage")
}
}
type dynamicMockFetcher struct {
protocolPercents edgediscovery.ProtocolPercents
err error
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() PercentageFetcher {
return func() (edgediscovery.ProtocolPercents, error) {
return dmf.protocolPercents, dmf.err
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
@@ -65,7 +69,6 @@ func TestNewProtocolSelector(t *testing.T) {
name: "named tunnel over h2mux",
protocol: "h2mux",
expectedProtocol: H2mux,
fetchFunc: func() (edgediscovery.ProtocolPercents, error) { return nil, nil },
namedTunnelConfig: testNamedTunnelConfig,
},
{
@@ -74,38 +77,28 @@ func TestNewProtocolSelector(t *testing.T) {
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel http2 disabled",
protocol: "http2",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel quic disabled",
protocol: "quic",
expectedProtocol: HTTP2,
// Hasfallback true is because if http2 fails, then we further fallback to h2mux.
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: -1}),
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto all http2 disabled",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
fetchFunc: mockFetcher(-1),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to h2mux",
protocol: "auto",
expectedProtocol: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}),
fetchFunc: mockFetcher(0),
namedTunnelConfig: testNamedTunnelConfig,
},
{
@@ -114,71 +107,36 @@ func TestNewProtocolSelector(t *testing.T) {
expectedProtocol: HTTP2,
hasFallback: true,
expectedFallback: H2mux,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "named tunnel auto to quic",
protocol: "auto",
expectedProtocol: QUIC,
hasFallback: true,
expectedFallback: HTTP2,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux",
protocol: "h2mux",
expectedProtocol: HTTP2Warp,
expectedProtocol: HTTP2,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing requesting h2mux picks HTTP2 even if http2 percent is -1",
protocol: "h2mux",
expectedProtocol: HTTP2Warp,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}),
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing http2",
protocol: "http2",
expectedProtocol: HTTP2Warp,
expectedProtocol: HTTP2,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing quic",
protocol: "quic",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto",
protocol: "auto",
expectedProtocol: HTTP2Warp,
expectedProtocol: HTTP2,
hasFallback: false,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
{
name: "warp routing auto- quic",
protocol: "auto",
expectedProtocol: QUICWarp,
hasFallback: true,
expectedFallback: HTTP2Warp,
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}, edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}),
expectedFallback: H2mux,
fetchFunc: mockFetcher(100),
warpRoutingEnabled: true,
namedTunnelConfig: testNamedTunnelConfig,
},
@@ -191,14 +149,14 @@ func TestNewProtocolSelector(t *testing.T) {
{
name: "named tunnel unknown protocol",
protocol: "unknown",
fetchFunc: mockFetcher(false, edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}),
fetchFunc: mockFetcher(100),
namedTunnelConfig: testNamedTunnelConfig,
wantErr: true,
},
{
name: "named tunnel fetch error",
protocol: "auto",
fetchFunc: mockFetcher(true),
protocol: "unknown",
fetchFunc: mockFetcherWithError(),
namedTunnelConfig: testNamedTunnelConfig,
expectedProtocol: HTTP2,
wantErr: false,
@@ -206,20 +164,18 @@ func TestNewProtocolSelector(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
assert.NoError(t, err, fmt.Sprintf("test %s failed", test.name))
assert.Equal(t, test.expectedProtocol, selector.Current(), fmt.Sprintf("test %s failed", test.name))
fallback, ok := selector.Fallback()
assert.Equal(t, test.hasFallback, ok, fmt.Sprintf("test %s failed", test.name))
if test.hasFallback {
assert.Equal(t, test.expectedFallback, fallback, fmt.Sprintf("test %s failed", test.name))
}
})
}
}
}
@@ -229,27 +185,27 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
fetcher.percentage = 0
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
assert.Equal(t, QUIC, selector.Current())
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
}
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
@@ -258,36 +214,35 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.err = fmt.Errorf("failed to fetch")
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.percentage = -1
fetcher.err = nil
assert.Equal(t, H2mux, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 0}}
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
fetcher.percentage = 100
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: -1}}
fetcher.percentage = -1
assert.Equal(t, H2mux, selector.Current())
}
func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
fetcher := dynamicMockFetcher{percentage: 100}
selector, err := NewProtocolSelector("auto", noWarpRoutingEnabled, testNamedTunnelConfig, fetcher.fetch(), time.Hour, &log)
assert.NoError(t, err)
assert.Equal(t, QUIC, selector.Current())
assert.Equal(t, HTTP2, selector.Current())
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 0}}
assert.Equal(t, QUIC, selector.Current())
fetcher.percentage = 0
assert.Equal(t, HTTP2, selector.Current())
}