TUN-7700: Implement feature selector to determine if connections will prefer post quantum cryptography

This commit is contained in:
Chung-Ting Huang
2023-08-25 14:39:25 +01:00
parent 38d3c3cae5
commit bec683b67d
8 changed files with 385 additions and 43 deletions

View File

@@ -392,7 +392,7 @@ func StartServer(
observer.SendURL(quickTunnelURL)
}
tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(ctx, c, info, log, logTransport, observer, namedTunnel)
if err != nil {
log.Err(err).Msg("Couldn't start tunnel")
return err

View File

@@ -4,10 +4,12 @@ import (
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/features"
)
func TestDedup(t *testing.T) {
expected := []string{"a", "b"}
actual := dedup([]string{"a", "b", "a"})
actual := features.Dedup([]string{"a", "b", "a"})
require.ElementsMatch(t, expected, actual)
}

View File

@@ -1,6 +1,7 @@
package tunnel
import (
"context"
"crypto/tls"
"fmt"
"net"
@@ -112,6 +113,7 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope
}
func prepareTunnelConfig(
ctx context.Context,
c *cli.Context,
info *cliutil.BuildInfo,
log, logTransport *zerolog.Logger,
@@ -131,22 +133,36 @@ func prepareTunnelConfig(
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID.String()})
transportProtocol := c.String("protocol")
needPQ := c.Bool("post-quantum")
if needPQ {
clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
staticFeatures := features.StaticFeatures{}
if c.Bool("post-quantum") {
if FipsEnabled {
return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode")
}
pqMode := features.PostQuantumStrict
staticFeatures.PostQuantumMode = &pqMode
}
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log)
if err != nil {
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
}
pqMode := featureSelector.PostQuantumMode()
if pqMode == features.PostQuantumStrict {
// Error if the user tries to force a non-quic transport protocol
if transportProtocol != connection.AutoSelectFlag && transportProtocol != connection.QUIC.String() {
return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport")
}
transportProtocol = connection.QUIC.String()
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
log.Info().Msgf(
"Using hybrid post-quantum key agreement %s",
supervisor.PQKexName,
)
}
clientFeatures := dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
if needPQ {
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
}
namedTunnel.Client = tunnelpogs.ClientInfo{
ClientID: clientID[:],
Features: clientFeatures,
@@ -202,13 +218,6 @@ func prepareTunnelConfig(
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
}
if needPQ {
log.Info().Msgf(
"Using hybrid post-quantum key agreement %s",
supervisor.PQKexName,
)
}
tunnelConfig := &supervisor.TunnelConfig{
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool("force"),
@@ -233,7 +242,7 @@ func prepareTunnelConfig(
NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
NeedPQ: needPQ,
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
@@ -276,25 +285,6 @@ func isRunningFromTerminal() bool {
return term.IsTerminal(int(os.Stdout.Fd()))
}
// Remove any duplicates from the slice
func dedup(slice []string) []string {
// Convert the slice into a set
set := make(map[string]bool, 0)
for _, str := range slice {
set[str] = true
}
// Convert the set back into a slice
keys := make([]string, len(set))
i := 0
for str := range set {
keys[i] = str
i++
}
return keys
}
// ParseConfigIPVersion returns the IP version from possible expected values from config
func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err error) {
switch version {