mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:49:58 +00:00
TUN-1196: Allow TLS config client CA and root CA to be constructed from multiple certificates
This commit is contained in:
@@ -6,150 +6,81 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
|
||||
"github.com/cloudflare/cloudflared/log"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/urfave/cli.v2"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var logger = log.CreateLogger()
|
||||
|
||||
// CLIFlags names the flags used to configure TLS for a command or subsystem.
|
||||
// The nil value for a field means the flag is ignored.
|
||||
type CLIFlags struct {
|
||||
Cert string
|
||||
Key string
|
||||
ClientCert string
|
||||
RootCA string
|
||||
// Config is the user provided parameters to create a tls.Config
|
||||
type TLSParameters struct {
|
||||
Cert string
|
||||
Key string
|
||||
GetCertificate *CertReloader
|
||||
ClientCAs []string
|
||||
RootCAs []string
|
||||
ServerName string
|
||||
CurvePreferences []tls.CurveID
|
||||
}
|
||||
|
||||
// GetConfig returns a TLS configuration according to the flags defined in f and
|
||||
// set by the user.
|
||||
func (f CLIFlags) GetConfig(c *cli.Context) *tls.Config {
|
||||
config := &tls.Config{}
|
||||
|
||||
if c.IsSet(f.Cert) && c.IsSet(f.Key) {
|
||||
cert, err := tls.LoadX509KeyPair(c.String(f.Cert), c.String(f.Key))
|
||||
// GetConfig returns a TLS configuration according to the Config set by the user.
|
||||
func GetConfig(p *TLSParameters) (*tls.Config, error) {
|
||||
tlsconfig := &tls.Config{}
|
||||
if p.GetCertificate != nil {
|
||||
tlsconfig.GetCertificate = p.GetCertificate.Cert
|
||||
tlsconfig.BuildNameToCertificate()
|
||||
} else if p.Cert != "" && p.Key != "" {
|
||||
cert, err := tls.LoadX509KeyPair(p.Cert, p.Key)
|
||||
if err != nil {
|
||||
logger.WithError(err).Fatal("Error parsing X509 key pair")
|
||||
return nil, errors.Wrap(err, "Error parsing X509 key pair")
|
||||
}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
config.BuildNameToCertificate()
|
||||
tlsconfig.Certificates = []tls.Certificate{cert}
|
||||
tlsconfig.BuildNameToCertificate()
|
||||
}
|
||||
return f.finishGettingConfig(c, config)
|
||||
}
|
||||
|
||||
func (f CLIFlags) GetConfigReloadableCert(c *cli.Context, cr *CertReloader) *tls.Config {
|
||||
config := &tls.Config{
|
||||
GetCertificate: cr.Cert,
|
||||
}
|
||||
config.BuildNameToCertificate()
|
||||
return f.finishGettingConfig(c, config)
|
||||
}
|
||||
|
||||
func (f CLIFlags) finishGettingConfig(c *cli.Context, config *tls.Config) *tls.Config {
|
||||
if c.IsSet(f.ClientCert) {
|
||||
if len(p.ClientCAs) > 0 {
|
||||
// set of root certificate authorities that servers use if required to verify a client certificate
|
||||
// by the policy in ClientAuth
|
||||
config.ClientCAs = LoadCert(c.String(f.ClientCert))
|
||||
clientCAs, err := LoadCert(p.ClientCAs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading client CAs")
|
||||
}
|
||||
tlsconfig.ClientCAs = clientCAs
|
||||
// server's policy for TLS Client Authentication. Default is no client cert
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsconfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
// set of root certificate authorities that clients use when verifying server certificates
|
||||
if c.IsSet(f.RootCA) {
|
||||
config.RootCAs = LoadCert(c.String(f.RootCA))
|
||||
|
||||
if len(p.RootCAs) > 0 {
|
||||
rootCAs, err := LoadCert(p.RootCAs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Error loading root CAs")
|
||||
}
|
||||
tlsconfig.RootCAs = rootCAs
|
||||
}
|
||||
// we optimize CurveP256
|
||||
config.CurvePreferences = []tls.CurveID{tls.CurveP256}
|
||||
return config
|
||||
|
||||
if p.ServerName != "" {
|
||||
tlsconfig.ServerName = p.ServerName
|
||||
}
|
||||
|
||||
if len(p.CurvePreferences) > 0 {
|
||||
tlsconfig.CurvePreferences = p.CurvePreferences
|
||||
} else {
|
||||
// Cloudflare optimize CurveP256
|
||||
tlsconfig.CurvePreferences = []tls.CurveID{tls.CurveP256}
|
||||
}
|
||||
|
||||
return tlsconfig, nil
|
||||
}
|
||||
|
||||
// LoadCert creates a CertPool containing all certificates in a PEM-format file.
|
||||
func LoadCert(certPath string) *x509.CertPool {
|
||||
caCert, err := ioutil.ReadFile(certPath)
|
||||
if err != nil {
|
||||
logger.WithError(err).Fatalf("Error reading certificate %s", certPath)
|
||||
}
|
||||
func LoadCert(certPaths []string) (*x509.CertPool, error) {
|
||||
ca := x509.NewCertPool()
|
||||
if !ca.AppendCertsFromPEM(caCert) {
|
||||
logger.WithError(err).Fatalf("Error parsing certificate %s", certPath)
|
||||
}
|
||||
return ca
|
||||
}
|
||||
|
||||
func LoadGlobalCertPool() (*x509.CertPool, error) {
|
||||
success := false
|
||||
|
||||
// First, obtain the system certificate pool
|
||||
certPool, systemCertPoolErr := x509.SystemCertPool()
|
||||
if systemCertPoolErr != nil {
|
||||
if runtime.GOOS != "windows" {
|
||||
logger.Warnf("error obtaining the system certificates: %s", systemCertPoolErr)
|
||||
for _, certPath := range certPaths {
|
||||
caCert, err := ioutil.ReadFile(certPath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "Error reading certificate %s", certPath)
|
||||
}
|
||||
certPool = x509.NewCertPool()
|
||||
} else {
|
||||
success = true
|
||||
}
|
||||
|
||||
// Next, append the Cloudflare CA pool into the system pool
|
||||
if !certPool.AppendCertsFromPEM(cloudflareRootCA) {
|
||||
logger.Warn("could not append the CF certificate to the cloudflared certificate pool")
|
||||
} else {
|
||||
success = true
|
||||
}
|
||||
|
||||
if success != true { // Obtaining any of the CAs has failed; this is a fatal error
|
||||
return nil, errors.New("error loading any of the CAs into the global certificate pool")
|
||||
}
|
||||
|
||||
// Finally, add the Hello certificate into the pool (since it's self-signed)
|
||||
helloCertificate, err := GetHelloCertificateX509()
|
||||
if err != nil {
|
||||
logger.Warn("error obtaining the Hello server certificate")
|
||||
}
|
||||
|
||||
certPool.AddCert(helloCertificate)
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func LoadOriginCertPool(originCAPoolPEM []byte) (*x509.CertPool, error) {
|
||||
success := false
|
||||
|
||||
// Get the global pool
|
||||
certPool, globalPoolErr := LoadGlobalCertPool()
|
||||
if globalPoolErr != nil {
|
||||
certPool = x509.NewCertPool()
|
||||
} else {
|
||||
success = true
|
||||
}
|
||||
|
||||
// Then, add any custom origin CA pool the user may have passed
|
||||
if originCAPoolPEM != nil {
|
||||
if !certPool.AppendCertsFromPEM(originCAPoolPEM) {
|
||||
logger.Warn("could not append the provided origin CA to the cloudflared certificate pool")
|
||||
} else {
|
||||
success = true
|
||||
if !ca.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.Wrapf(err, "Error parsing certificate %s", certPath)
|
||||
}
|
||||
}
|
||||
|
||||
if success != true {
|
||||
return nil, errors.New("error loading any of the CAs into the origin certificate pool")
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
func CreateTunnelConfig(c *cli.Context, addrs []string) *tls.Config {
|
||||
tlsConfig := CLIFlags{RootCA: "cacert"}.GetConfig(c)
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = GetCloudflareRootCA()
|
||||
tlsConfig.ServerName = "cftunnel.com"
|
||||
} else if len(addrs) > 0 {
|
||||
// Set for development environments and for testing specific origintunneld instances
|
||||
tlsConfig.ServerName, _, _ = net.SplitHostPort(addrs[0])
|
||||
}
|
||||
return tlsConfig
|
||||
return ca, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user