mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:49:57 +00:00
AUTH-2030: Support both authorized_key and short lived cert authentication simultaniously without specifiying at start time
This commit is contained in:
@@ -18,6 +18,23 @@ var (
|
||||
authorizedKeysDir = ".cloudflared/authorized_keys"
|
||||
)
|
||||
|
||||
func (s *SSHServer) configureAuthentication() {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
s.logger.Info(err)
|
||||
}
|
||||
s.caCert = caCert
|
||||
s.PublicKeyHandler = s.authenticationHandler
|
||||
}
|
||||
|
||||
func (s *SSHServer) authenticationHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
cert, ok := key.(*gossh.Certificate)
|
||||
if !ok {
|
||||
return s.authorizedKeyHandler(ctx, key)
|
||||
}
|
||||
return s.shortLivedCertHandler(ctx, cert)
|
||||
}
|
||||
|
||||
func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
sshUser, err := s.getUserFunc(ctx.User())
|
||||
if err != nil {
|
||||
@@ -56,20 +73,14 @@ func (s *SSHServer) authorizedKeyHandler(ctx ssh.Context, key ssh.PublicKey) boo
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
userCert, ok := key.(*gossh.Certificate)
|
||||
if !ok {
|
||||
s.logger.Debug("Received key is not an SSH certificate")
|
||||
return false
|
||||
}
|
||||
|
||||
if !ssh.KeysEqual(s.caCert, userCert.SignatureKey) {
|
||||
func (s *SSHServer) shortLivedCertHandler(ctx ssh.Context, cert *gossh.Certificate) bool {
|
||||
if !ssh.KeysEqual(s.caCert, cert.SignatureKey) {
|
||||
s.logger.Debug("CA certificate does not match user certificate signer")
|
||||
return false
|
||||
}
|
||||
|
||||
checker := gossh.CertChecker{}
|
||||
if err := checker.CheckCert(ctx.User(), userCert); err != nil {
|
||||
if err := checker.CheckCert(ctx.User(), cert); err != nil {
|
||||
s.logger.Debug(err)
|
||||
return false
|
||||
} else {
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/sirupsen/logrus/hooks/test"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -90,7 +91,8 @@ func TestShortLivedCerts_Success(t *testing.T) {
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.True(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
}
|
||||
|
||||
@@ -101,7 +103,8 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||
caCert := getKey(t, testOtherCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||
}
|
||||
@@ -113,7 +116,8 @@ func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||
}
|
||||
@@ -125,7 +129,8 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
|
||||
userCert := getKey(t, testUserCertFilename)
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||
}
|
||||
|
@@ -28,7 +28,7 @@ type SSHServer struct {
|
||||
getUserFunc func(string) (*User, error)
|
||||
}
|
||||
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLivedCertAuth bool, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||
func New(logger *logrus.Logger, address string, shutdownC chan struct{}, idleTimeout, maxTimeout time.Duration) (*SSHServer, error) {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -48,17 +48,7 @@ func New(logger *logrus.Logger, address string, shutdownC chan struct{}, shortLi
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if shortLivedCertAuth {
|
||||
caCert, err := getCACert()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshServer.caCert = caCert
|
||||
sshServer.PublicKeyHandler = sshServer.shortLivedCertHandler
|
||||
} else {
|
||||
sshServer.PublicKeyHandler = sshServer.authorizedKeyHandler
|
||||
}
|
||||
|
||||
sshServer.configureAuthentication()
|
||||
return &sshServer, nil
|
||||
}
|
||||
|
||||
@@ -111,6 +101,7 @@ func (s *SSHServer) connectionHandler(session ssh.Session) {
|
||||
return
|
||||
}
|
||||
|
||||
// Supplementary groups are not explicitly specified. They seem to be inherited by default.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Credential: &syscall.Credential{Uid: uidInt, Gid: gidInt}}
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("USER=%s", sshUser.Username))
|
||||
|
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
type SSHServer struct{}
|
||||
|
||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _ bool, _, _ time.Duration) (*SSHServer, error) {
|
||||
func New(_ *logrus.Logger, _ string, _ chan struct{}, _, _ time.Duration) (*SSHServer, error) {
|
||||
return nil, errors.New("cloudflared ssh server is not supported on windows")
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user