RTG-1339 Support post-quantum hybrid key exchange

Func spec: https://wiki.cfops.it/x/ZcBKHw
This commit is contained in:
Bas Westerbaan
2022-08-24 14:33:10 +02:00
committed by Devin Carr
parent 3e0ff3a771
commit 11cbff4ff7
171 changed files with 15270 additions and 196 deletions

View File

@@ -1,6 +1,7 @@
package connection
import (
"errors"
"fmt"
"hash/fnv"
"sync"
@@ -130,6 +131,7 @@ type autoProtocolSelector struct {
refreshAfter time.Time
ttl time.Duration
log *zerolog.Logger
needPQ bool
}
func newAutoProtocolSelector(
@@ -139,6 +141,7 @@ func newAutoProtocolSelector(
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
needPQ bool,
) *autoProtocolSelector {
return &autoProtocolSelector{
current: current,
@@ -148,6 +151,7 @@ func newAutoProtocolSelector(
refreshAfter: time.Now().Add(ttl),
ttl: ttl,
log: log,
needPQ: needPQ,
}
}
@@ -187,6 +191,9 @@ func getProtocol(protocolPool []Protocol, fetchFunc PercentageFetcher, switchThr
func (s *autoProtocolSelector) Fallback() (Protocol, bool) {
s.lock.RLock()
defer s.lock.RUnlock()
if s.needPQ {
return 0, false
}
return s.current.fallback()
}
@@ -199,9 +206,14 @@ func NewProtocolSelector(
fetchFunc PercentageFetcher,
ttl time.Duration,
log *zerolog.Logger,
needPQ bool,
) (ProtocolSelector, error) {
// Classic tunnel is only supported with h2mux
if namedTunnel == nil {
if needPQ {
return nil, errors.New("Classic tunnel does not support post-quantum")
}
return &staticProtocolSelector{
current: H2mux,
}, nil
@@ -211,6 +223,9 @@ func NewProtocolSelector(
fetchedProtocol, err := getProtocol([]Protocol{QUIC, HTTP2}, fetchFunc, threshold)
if err != nil && protocolFlag == "auto" {
log.Err(err).Msg("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can attempt `--protocol quic` instead.")
if needPQ {
return nil, errors.New("http2 does not support post-quantum")
}
return &staticProtocolSelector{
current: HTTP2,
}, nil
@@ -221,10 +236,10 @@ func NewProtocolSelector(
protocolFlag = HTTP2.String()
fetchedProtocol = HTTP2Warp
}
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
return selectWarpRoutingProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol, needPQ)
}
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol)
return selectNamedTunnelProtocols(protocolFlag, fetchFunc, ttl, log, threshold, fetchedProtocol, needPQ)
}
func selectNamedTunnelProtocols(
@@ -234,6 +249,7 @@ func selectNamedTunnelProtocols(
log *zerolog.Logger,
threshold int32,
protocol Protocol,
needPQ bool,
) (ProtocolSelector, error) {
// If the user picks a protocol, then we stick to it no matter what.
switch protocolFlag {
@@ -248,7 +264,7 @@ func selectNamedTunnelProtocols(
// If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and
// fallback on failures.
if protocolFlag == AutoSelectFlag {
return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(protocol, []Protocol{QUIC, HTTP2, H2mux}, threshold, fetchFunc, ttl, log, needPQ), nil
}
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)
@@ -261,6 +277,7 @@ func selectWarpRoutingProtocols(
log *zerolog.Logger,
threshold int32,
protocol Protocol,
needPQ bool,
) (ProtocolSelector, error) {
// If the user picks a protocol, then we stick to it no matter what.
switch protocolFlag {
@@ -273,7 +290,7 @@ func selectWarpRoutingProtocols(
// If the user does not pick (hopefully the majority) then we use the one derived from the TXT DNS record and
// fallback on failures.
if protocolFlag == AutoSelectFlag {
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log), nil
return newAutoProtocolSelector(protocol, []Protocol{QUICWarp, HTTP2Warp}, threshold, fetchFunc, ttl, log, needPQ), nil
}
return nil, fmt.Errorf("Unknown protocol %s, %s", protocolFlag, AvailableProtocolFlagMessage)

View File

@@ -219,7 +219,7 @@ func TestNewProtocolSelector(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log)
selector, err := NewProtocolSelector(test.protocol, test.warpRoutingEnabled, test.namedTunnelConfig, test.fetchFunc, testNoTTL, &log, false)
if test.wantErr {
assert.Error(t, err, fmt.Sprintf("test %s failed", test.name))
} else {
@@ -237,7 +237,7 @@ func TestNewProtocolSelector(t *testing.T) {
func TestAutoProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false)
assert.NoError(t, err)
assert.Equal(t, H2mux, selector.Current())
@@ -267,7 +267,7 @@ func TestAutoProtocolSelectorRefresh(t *testing.T) {
func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
fetcher := dynamicMockFetcher{}
// Since the user chooses http2 on purpose, we always stick to it.
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log)
selector, err := NewProtocolSelector("http2", noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), testNoTTL, &log, false)
assert.NoError(t, err)
assert.Equal(t, HTTP2, selector.Current())
@@ -297,7 +297,7 @@ func TestHTTP2ProtocolSelectorRefresh(t *testing.T) {
func TestProtocolSelectorRefreshTTL(t *testing.T) {
fetcher := dynamicMockFetcher{}
fetcher.protocolPercents = edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}}
selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log)
selector, err := NewProtocolSelector(AutoSelectFlag, noWarpRoutingEnabled, testNamedTunnelProperties, fetcher.fetch(), time.Hour, &log, false)
assert.NoError(t, err)
assert.Equal(t, QUIC, selector.Current())