cloudflared/connection/protocol_test.go
Devin Carr 0f95f8bae5 TUN-6938: Force h2mux protocol to http2 for named tunnels
Going forward, the only protocols supported will be QUIC and HTTP2,
defaulting to QUIC for "auto". If the h2mux protocol is selected, it will
be forcibly upgraded to http2 internally.
2023-02-06 11:06:02 -08:00

174 lines
6.0 KiB
Go

package connection
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/edgediscovery"
)
// Shared fixtures for the protocol selector tests in this file.
const (
	// testAccountTag is the account tag argument handed to NewProtocolSelector.
	testAccountTag = "testAccountTag"
	// testNoTTL is passed as the selector's TTL argument. The refresh tests
	// rely on it so that each Current() call observes the fetcher's latest state.
	testNoTTL = 0
)
// mockFetcher builds a PercentageFetcher with fixed behavior: when getError is
// true every call fails with "failed to fetch percentage"; otherwise every call
// returns the protocolPercent values supplied here.
func mockFetcher(getError bool, protocolPercent ...edgediscovery.ProtocolPercent) edgediscovery.PercentageFetcher {
	fixed := func() (edgediscovery.ProtocolPercents, error) {
		if !getError {
			return protocolPercent, nil
		}
		return nil, fmt.Errorf("failed to fetch percentage")
	}
	return fixed
}
// dynamicMockFetcher is a test double whose responses can be changed between
// calls: tests assign protocolPercents and err directly, and the fetcher
// returned by fetch reports whatever values are set at the moment it is called.
type dynamicMockFetcher struct {
	protocolPercents edgediscovery.ProtocolPercents
	err              error
}

// fetch returns a PercentageFetcher bound to f. The closure reads f's fields
// on every invocation (not when fetch is called), so later mutations by the
// test are visible to the selector.
func (f *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
	current := func() (edgediscovery.ProtocolPercents, error) {
		return f.protocolPercents, f.err
	}
	return current
}
// TestNewProtocolSelector checks which protocol, and which optional fallback,
// NewProtocolSelector picks for each requested protocol string. The cases
// cover rejection of unknown protocols, coercion of h2mux to HTTP2, QUIC with
// an HTTP2 fallback for auto-select, and needPQ forcing QUIC even when http2
// is requested.
func TestNewProtocolSelector(t *testing.T) {
	tests := []struct {
		name                string
		protocol            string
		tunnelTokenProvided bool
		needPQ              bool
		expectedProtocol    Protocol
		hasFallback         bool
		expectedFallback    Protocol
		wantErr             bool
	}{
		{
			name:     "named tunnel with unknown protocol",
			protocol: "unknown",
			wantErr:  true,
		},
		{
			name:             "named tunnel with h2mux: force to http2",
			protocol:         "h2mux",
			expectedProtocol: HTTP2,
		},
		{
			name:             "named tunnel with http2: no fallback",
			protocol:         "http2",
			expectedProtocol: HTTP2,
		},
		{
			name:             "named tunnel with auto: quic",
			protocol:         AutoSelectFlag,
			expectedProtocol: QUIC,
			hasFallback:      true,
			expectedFallback: HTTP2,
		},
		{
			name:             "named tunnel (post quantum)",
			protocol:         AutoSelectFlag,
			needPQ:           true,
			expectedProtocol: QUIC,
		},
		{
			name:             "named tunnel (post quantum) w/http2",
			protocol:         "http2",
			needPQ:           true,
			expectedProtocol: QUIC,
		},
	}

	fetcher := dynamicMockFetcher{
		protocolPercents: edgediscovery.ProtocolPercents{},
	}
	for _, test := range tests {
		// Failure output is already labeled with the subtest name by t.Run,
		// so the assertions below carry no extra message.
		t.Run(test.name, func(t *testing.T) {
			selector, err := NewProtocolSelector(test.protocol, testAccountTag, test.tunnelTokenProvided, test.needPQ, fetcher.fetch(), ResolveTTL, &log)
			if test.wantErr {
				assert.Error(t, err)
				return
			}
			// Stop on an unexpected error: the selector must not be used when
			// construction fails, and calling its methods could panic and hide
			// the real failure.
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, test.expectedProtocol, selector.Current())

			fallback, ok := selector.Fallback()
			assert.Equal(t, test.hasFallback, ok)
			if test.hasFallback {
				assert.Equal(t, test.expectedFallback, fallback)
			}
		})
	}
}
// TestAutoProtocolSelectorRefresh verifies that an auto-select selector
// created without a tunnel token follows the fetcher's percentages. It is
// built with a zero TTL (testNoTTL), and each Current() call is expected to
// reflect the fetcher's latest state:
//   - http2 at 100% selects HTTP2.
//   - http2 at 0% or a negative percentage selects QUIC.
//   - A fetch error is expected to leave HTTP2 selected.
func TestAutoProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
	// Stop on an unexpected error: the selector must not be used when
	// construction fails, and calling its methods could panic and hide the
	// real failure.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, QUIC, selector.Current())

	// Each step overwrites the fetcher's full state (percentages and error)
	// before calling Current(), so the expected protocol of every step
	// depends only on that row and on the selector's history.
	steps := []struct {
		percents edgediscovery.ProtocolPercents
		fetchErr error
		want     Protocol
	}{
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}, want: HTTP2},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}}, want: QUIC},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}, want: HTTP2},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}, fetchErr: fmt.Errorf("failed to fetch"), want: HTTP2},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}}, want: QUIC},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}}, want: QUIC},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "quic", Percentage: 100}}, want: QUIC},
	}
	for i, step := range steps {
		fetcher.protocolPercents = step.percents
		fetcher.err = step.fetchErr
		assert.Equal(t, step.want, selector.Current(), "step %d", i)
	}
}
// TestHTTP2ProtocolSelectorRefresh verifies that a selector created with an
// explicit "http2" protocol stays on HTTP2 whatever the fetcher reports. This
// includes 0%, 100%, negative percentages and fetch errors.
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	// Since the user chooses http2 on purpose, we always stick to it.
	selector, err := NewProtocolSelector(HTTP2.String(), testAccountTag, false, false, fetcher.fetch(), testNoTTL, &log)
	// Stop on an unexpected error: the selector must not be used when
	// construction fails, and calling its methods could panic and hide the
	// real failure.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, HTTP2, selector.Current())

	// Each step overwrites the fetcher's full state before calling Current().
	// The error step keeps the http2:0 percentages that preceded it.
	steps := []struct {
		percents edgediscovery.ProtocolPercents
		fetchErr error
	}{
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}}},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}}, fetchErr: fmt.Errorf("failed to fetch")},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}}},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 0}}},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: 100}}},
		{percents: edgediscovery.ProtocolPercents{{Protocol: "http2", Percentage: -1}}},
	}
	for i, step := range steps {
		fetcher.protocolPercents = step.percents
		fetcher.err = step.fetchErr
		assert.Equal(t, HTTP2, selector.Current(), "step %d", i)
	}
}
// TestAutoProtocolSelectorNoRefreshWithToken verifies the behavior of an
// auto-select selector when a tunnel token is provided (the third
// NewProtocolSelector argument). It should start on QUIC and keep QUIC even
// after the fetcher starts reporting 100% http2.
func TestAutoProtocolSelectorNoRefreshWithToken(t *testing.T) {
	fetcher := dynamicMockFetcher{}
	selector, err := NewProtocolSelector(AutoSelectFlag, testAccountTag, true, false, fetcher.fetch(), testNoTTL, &log)
	// Stop on an unexpected error: the selector must not be used when
	// construction fails, and calling its methods could panic and hide the
	// real failure.
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, QUIC, selector.Current())

	fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "http2", Percentage: 100}}
	// Even with a zero TTL, the token-based selector is expected to ignore
	// the updated percentages.
	assert.Equal(t, QUIC, selector.Current())
}