mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 22:29:57 +00:00
AUTH-2036: Refactor user retrieval, shutdown after ssh server stops, add custom version string
This commit is contained in:
@@ -22,7 +22,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
validPrincipal = "testUser"
|
||||
testDir = "testdata"
|
||||
testUserKeyFilename = "id_rsa.pub"
|
||||
testCAFilename = "ca.pub"
|
||||
@@ -30,7 +29,10 @@ const (
|
||||
testUserCertFilename = "id_rsa-cert.pub"
|
||||
)
|
||||
|
||||
var logger, hook = test.NewNullLogger()
|
||||
var (
|
||||
logger, hook = test.NewNullLogger()
|
||||
mockUser = &User{Username: "testUser", HomeDir: testDir}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
authorizedKeysDir = testUserKeyFilename
|
||||
@@ -40,20 +42,20 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_Success(t *testing.T) {
|
||||
context, cancel := newMockContext(validPrincipal)
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{getUserFunc: getMockUser}
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.True(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
||||
context, cancel := newMockContext(validPrincipal)
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{logger: logger, getUserFunc: getMockUser}
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testOtherCAFilename)
|
||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
@@ -61,23 +63,27 @@ func TestPublicKeyAuth_MissingKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_InvalidUser(t *testing.T) {
|
||||
context, cancel := newMockContext("notAUser")
|
||||
context, cancel := newMockContext(&User{Username: "notAUser"})
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{logger: logger, getUserFunc: lookupUser}
|
||||
sshServer := SSHServer{logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
assert.False(t, sshServer.authenticationHandler(context, pubKey))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||
}
|
||||
|
||||
func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
||||
currentUser, err := user.Current()
|
||||
tempUser, err := user.Current()
|
||||
require.Nil(t, err)
|
||||
context, cancel := newMockContext(currentUser.Username)
|
||||
currentUser, err := lookupUser(tempUser.Username)
|
||||
require.Nil(t, err)
|
||||
|
||||
require.Nil(t, err)
|
||||
context, cancel := newMockContext(currentUser)
|
||||
defer cancel()
|
||||
|
||||
sshServer := SSHServer{Server: ssh.Server{}, logger: logger, getUserFunc: lookupUser}
|
||||
sshServer := SSHServer{Server: ssh.Server{}, logger: logger}
|
||||
|
||||
pubKey := getKey(t, testUserKeyFilename)
|
||||
assert.False(t, sshServer.authorizedKeyHandler(context, pubKey))
|
||||
@@ -85,11 +91,11 @@ func TestPublicKeyAuth_MissingFile(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_Success(t *testing.T) {
|
||||
context, cancel := newMockContext(validPrincipal)
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert, getUserFunc: getMockUser}
|
||||
sshServer := SSHServer{logger: log.CreateLogger(), caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
@@ -97,11 +103,11 @@ func TestShortLivedCerts_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||
context, cancel := newMockContext(validPrincipal)
|
||||
context, cancel := newMockContext(mockUser)
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testOtherCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: getMockUser}
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
@@ -109,25 +115,12 @@ func TestShortLivedCerts_CAsDontMatch(t *testing.T) {
|
||||
assert.Equal(t, "CA certificate does not match user certificate signer", hook.LastEntry().Message)
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_UserDoesNotExist(t *testing.T) {
|
||||
context, cancel := newMockContext(validPrincipal)
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
assert.False(t, sshServer.shortLivedCertHandler(context, userCert))
|
||||
assert.Contains(t, hook.LastEntry().Message, "Invalid user")
|
||||
}
|
||||
|
||||
func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||
context, cancel := newMockContext("notAUser")
|
||||
context, cancel := newMockContext(&User{Username: "NotAUser"})
|
||||
defer cancel()
|
||||
|
||||
caCert := getKey(t, testCAFilename)
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert, getUserFunc: lookupUser}
|
||||
sshServer := SSHServer{logger: logger, caCert: caCert}
|
||||
|
||||
userCert, ok := getKey(t, testUserCertFilename).(*gossh.Certificate)
|
||||
require.True(t, ok)
|
||||
@@ -135,14 +128,6 @@ func TestShortLivedCerts_InvalidPrincipal(t *testing.T) {
|
||||
assert.Contains(t, hook.LastEntry().Message, "not in the set of valid principals for given certificate")
|
||||
}
|
||||
|
||||
func getMockUser(_ string) (*User, error) {
|
||||
return &User{
|
||||
Username: validPrincipal,
|
||||
HomeDir: testDir,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func getKey(t *testing.T, filename string) ssh.PublicKey {
|
||||
path := path.Join(testDir, filename)
|
||||
bytes, err := ioutil.ReadFile(path)
|
||||
@@ -157,10 +142,13 @@ type mockSSHContext struct {
|
||||
*sync.Mutex
|
||||
}
|
||||
|
||||
func newMockContext(user string) (*mockSSHContext, context.CancelFunc) {
|
||||
func newMockContext(user *User) (*mockSSHContext, context.CancelFunc) {
|
||||
innerCtx, cancel := context.WithCancel(context.Background())
|
||||
mockCtx := &mockSSHContext{innerCtx, &sync.Mutex{}}
|
||||
mockCtx.SetValue("user", user)
|
||||
mockCtx.SetValue("sshUser", user)
|
||||
|
||||
// This naming is confusing but we cant change it because this mocks the SSHContext struct in gliderlabs/ssh
|
||||
mockCtx.SetValue("user", user.Username)
|
||||
return mockCtx, cancel
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user