mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 15:19:58 +00:00
AUTH-2052: Adds tests for SSH server
This commit is contained in:
195
ssh_server_tests/tests.py
Normal file
195
ssh_server_tests/tests.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Cloudflared Integration tests
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import subprocess
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
||||
from pexpect import pxssh
|
||||
|
||||
|
||||
class TestSSHBase(unittest.TestCase):
|
||||
"""
|
||||
SSH test base class containing constants and helper funcs
|
||||
"""
|
||||
|
||||
HOSTNAME = os.environ["SSH_HOSTNAME"]
|
||||
SSH_USER = os.environ["SSH_USER"]
|
||||
SSH_TARGET = f"{SSH_USER}@{HOSTNAME}"
|
||||
AUTHORIZED_KEYS_SSH_CONFIG = os.environ["AUTHORIZED_KEYS_SSH_CONFIG"]
|
||||
SHORT_LIVED_CERT_SSH_CONFIG = os.environ["SHORT_LIVED_CERT_SSH_CONFIG"]
|
||||
SSH_OPTIONS = {"StrictHostKeyChecking": "no"}
|
||||
|
||||
@classmethod
|
||||
def get_ssh_command(cls, pty=True):
|
||||
"""
|
||||
Return ssh command arg list. If pty is true, a PTY is forced for the session.
|
||||
"""
|
||||
cmd = [
|
||||
"ssh",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-F",
|
||||
cls.AUTHORIZED_KEYS_SSH_CONFIG,
|
||||
cls.SSH_TARGET,
|
||||
]
|
||||
if not pty:
|
||||
cmd += ["-T"]
|
||||
else:
|
||||
cmd += ["-tt"]
|
||||
|
||||
return cmd
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def ssh_session_manager(cls, *args, **kwargs):
|
||||
"""
|
||||
Context manager for interacting with a pxssh session.
|
||||
Disables pty echo on the remote server and ensures session is terminated afterward.
|
||||
"""
|
||||
session = pxssh.pxssh(options=cls.SSH_OPTIONS)
|
||||
|
||||
session.login(
|
||||
cls.HOSTNAME,
|
||||
username=cls.SSH_USER,
|
||||
original_prompt=r"[#@$]",
|
||||
ssh_config=kwargs.get("ssh_config", cls.AUTHORIZED_KEYS_SSH_CONFIG),
|
||||
ssh_tunnels=kwargs.get("ssh_tunnels", {}),
|
||||
)
|
||||
try:
|
||||
session.sendline("stty -echo")
|
||||
session.prompt()
|
||||
yield session
|
||||
finally:
|
||||
session.logout()
|
||||
|
||||
@staticmethod
|
||||
def get_command_output(session, cmd):
|
||||
"""
|
||||
Executes command on remote ssh server and waits for prompt.
|
||||
Returns command output
|
||||
"""
|
||||
session.sendline(cmd)
|
||||
session.prompt()
|
||||
return session.before.decode().strip()
|
||||
|
||||
def exec_command(self, cmd, shell=False):
|
||||
"""
|
||||
Executes command locally. Raises Assertion error for non-zero return code.
|
||||
Returns stdout and stderr
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd, stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=shell
|
||||
)
|
||||
raw_out, raw_err = proc.communicate()
|
||||
|
||||
out = raw_out.decode()
|
||||
err = raw_err.decode()
|
||||
self.assertEqual(proc.returncode, 0, msg=f"stdout: {out} stderr: {err}")
|
||||
return out.strip(), err.strip()
|
||||
|
||||
|
||||
class TestSSHCommandExec(TestSSHBase):
|
||||
"""
|
||||
Tests inline ssh command exec
|
||||
"""
|
||||
|
||||
# Name of file to be downloaded over SCP on remote server.
|
||||
REMOTE_SCP_FILENAME = os.environ["REMOTE_SCP_FILENAME"]
|
||||
|
||||
@classmethod
|
||||
def get_scp_base_command(cls):
|
||||
return [
|
||||
"scp",
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-v",
|
||||
"-F",
|
||||
cls.AUTHORIZED_KEYS_SSH_CONFIG,
|
||||
]
|
||||
|
||||
@unittest.skip(
|
||||
"This creates files on the remote. Should be skipped until server is dockerized."
|
||||
)
|
||||
def test_verbose_scp_sink_mode(self):
|
||||
with tempfile.NamedTemporaryFile() as fl:
|
||||
self.exec_command(
|
||||
self.get_scp_base_command() + [fl.name, f"{self.SSH_TARGET}:"]
|
||||
)
|
||||
|
||||
def test_verbose_scp_source_mode(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.exec_command(
|
||||
self.get_scp_base_command()
|
||||
+ [f"{self.SSH_TARGET}:{self.REMOTE_SCP_FILENAME}", tmpdirname]
|
||||
)
|
||||
local_filename = os.path.join(tmpdirname, self.REMOTE_SCP_FILENAME)
|
||||
|
||||
self.assertTrue(os.path.exists(local_filename))
|
||||
self.assertTrue(os.path.getsize(local_filename) > 0)
|
||||
|
||||
def test_pty_command(self):
|
||||
base_cmd = self.get_ssh_command()
|
||||
|
||||
out, _ = self.exec_command(base_cmd + ["whoami"])
|
||||
self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
|
||||
|
||||
out, _ = self.exec_command(base_cmd + ["tty"])
|
||||
self.assertNotEqual(out, "not a tty")
|
||||
|
||||
def test_non_pty_command(self):
|
||||
base_cmd = self.get_ssh_command(pty=False)
|
||||
|
||||
out, _ = self.exec_command(base_cmd + ["whoami"])
|
||||
self.assertEqual(out.strip().lower(), self.SSH_USER.lower())
|
||||
|
||||
out, _ = self.exec_command(base_cmd + ["tty"])
|
||||
self.assertEqual(out, "not a tty")
|
||||
|
||||
|
||||
class TestSSHShell(TestSSHBase):
|
||||
"""
|
||||
Tests interactive SSH shell
|
||||
"""
|
||||
|
||||
# File path to a file on the remote server with root only read privileges.
|
||||
ROOT_ONLY_TEST_FILE_PATH = os.environ["ROOT_ONLY_TEST_FILE_PATH"]
|
||||
|
||||
def test_ssh_pty(self):
|
||||
with self.ssh_session_manager() as session:
|
||||
|
||||
# Test shell launched as correct user
|
||||
username = self.get_command_output(session, "whoami")
|
||||
self.assertEqual(username.lower(), self.SSH_USER.lower())
|
||||
|
||||
# Test USER env variable set
|
||||
user_var = self.get_command_output(session, "echo $USER")
|
||||
self.assertEqual(user_var.lower(), self.SSH_USER.lower())
|
||||
|
||||
# Test HOME env variable set to true user home.
|
||||
home_env = self.get_command_output(session, "echo $HOME")
|
||||
pwd = self.get_command_output(session, "pwd")
|
||||
self.assertEqual(pwd, home_env)
|
||||
|
||||
# Test shell launched in correct user home dir.
|
||||
self.assertIn(username, pwd)
|
||||
|
||||
# Ensure shell launched with correct user's permissions and privs.
|
||||
# Cant read root owned 0700 files.
|
||||
output = self.get_command_output(
|
||||
session, f"cat {self.ROOT_ONLY_TEST_FILE_PATH}"
|
||||
)
|
||||
self.assertIn("Permission denied", output)
|
||||
|
||||
def test_short_lived_cert_auth(self):
|
||||
with self.ssh_session_manager(
|
||||
ssh_config=self.SHORT_LIVED_CERT_SSH_CONFIG
|
||||
) as session:
|
||||
username = self.get_command_output(session, "whoami")
|
||||
self.assertEqual(username.lower(), self.SSH_USER.lower())
|
||||
|
||||
|
||||
unittest.main()
|
Reference in New Issue
Block a user