diff --git a/supervisor/pqtunnels.go b/supervisor/pqtunnels.go index 2eaad9e8..30eb2e87 100644 --- a/supervisor/pqtunnels.go +++ b/supervisor/pqtunnels.go @@ -17,8 +17,8 @@ const ( ) var ( - nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex} - nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex} + nonFipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex} + nonFipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{X25519MLKEM768PQKex} fipsPostQuantumStrictPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex} fipsPostQuantumPreferPKex []tls.CurveID = []tls.CurveID{P256Kyber768Draft00PQKex, tls.CurveP256} ) diff --git a/supervisor/pqtunnels_test.go b/supervisor/pqtunnels_test.go index 383200db..3be54460 100644 --- a/supervisor/pqtunnels_test.go +++ b/supervisor/pqtunnels_test.go @@ -2,12 +2,16 @@ package supervisor import ( "crypto/tls" + "net/http" + "net/http/httptest" + "slices" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/cloudflare/cloudflared/features" + "github.com/cloudflare/cloudflared/fips" ) func TestCurvePreferences(t *testing.T) { @@ -48,7 +52,7 @@ func TestCurvePreferences(t *testing.T) { pqMode: features.PostQuantumPrefer, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, tls.CurveP256}, }, { name: "Non FIPS with Prefer PQ - no duplicates", @@ -62,14 +66,14 @@ func TestCurvePreferences(t *testing.T) { pqMode: features.PostQuantumPrefer, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256, X25519Kyber768Draft00PQKex}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex, tls.CurveP256}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, tls.CurveP256, X25519Kyber768Draft00PQKex}, }, { name: "Non FIPS with Strict PQ", pqMode: features.PostQuantumStrict, fipsEnabled: false, currentCurves: []tls.CurveID{tls.CurveP256}, - expectedCurves: []tls.CurveID{X25519MLKEM768PQKex, X25519Kyber768Draft00PQKex}, + expectedCurves: []tls.CurveID{X25519MLKEM768PQKex}, }, } @@ -82,3 +86,34 @@ func TestCurvePreferences(t *testing.T) { }) } } + +func runClientServerHandshake(t *testing.T, curves []tls.CurveID) []tls.CurveID { + var advertisedCurves []tls.CurveID + ts := httptest.NewUnstartedServer(nil) + ts.TLS = &tls.Config{ // nolint: gosec + GetConfigForClient: func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + advertisedCurves = slices.Clone(chi.SupportedCurves) + return nil, nil + }, + } + ts.StartTLS() + defer ts.Close() + clientTlsConfig := ts.Client().Transport.(*http.Transport).TLSClientConfig + clientTlsConfig.CurvePreferences = curves + resp, err := ts.Client().Head(ts.URL) + if err != nil { + t.Error(err) + return nil + } + defer resp.Body.Close() + return advertisedCurves +} + +func TestSupportedCurvesNegotiation(t *testing.T) { + for _, tcase := range []features.PostQuantumMode{features.PostQuantumPrefer} { + curves, err := curvePreference(tcase, fips.IsFipsEnabled(), make([]tls.CurveID, 0)) + require.NoError(t, err) + advertisedCurves := runClientServerHandshake(t, curves) + assert.Equal(t, curves, advertisedCurves) + } +}