mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-05-11 23:06:33 +00:00

This parameterizes the relevant component tests by transport protocol where applicable. The motivation is to get coverage for shutdown (graceful or not), which was broken for QUIC. That logic (as well as reconnection) differs depending on the transport, so the tests should be parameterized by it. In fact, the test currently fails for QUIC (and passes for the other protocols), which is expected until some edge fixes for QUIC are rolled out. With this coverage, we could have caught the problem earlier.
119 lines
4.5 KiB
Python
119 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
from contextlib import contextmanager
|
|
import platform
|
|
import signal
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
import requests
|
|
|
|
from constants import protocols
|
|
from util import start_cloudflared, wait_tunnel_ready, check_tunnel_not_connected
|
|
|
|
|
|
def supported_signals():
    """Return the list of termination signals to exercise on this platform.

    SIGTERM is always included; SIGINT is added on every platform except Windows.
    A new list is built on each call.
    """
    signals = [signal.SIGTERM]
    if platform.system() != "Windows":
        signals.append(signal.SIGINT)
    return signals
|
|
|
|
|
|
class TestTermination:
    """Component tests for how cloudflared shuts down after a termination signal.

    Each test is parametrized over the signals from ``supported_signals()`` and
    the transport protocols from ``protocols()``, because shutdown (and
    reconnect) logic differs per transport.
    """

    # Passed to cloudflared as the ``grace-period`` config value, in seconds.
    grace_period = 5
    # Seconds to wait for the eyeball SSE connection to be established; also
    # added to the grace period when joining the eyeball thread.
    timeout = 10
    # SSE endpoint on the test origin; ``freq=1s`` presumably asks the origin
    # for one event per second — TODO confirm against the origin implementation.
    sse_endpoint = "/sse?freq=1s"

    def _extra_config(self, protocol):
        """Return cloudflared config overrides for a run over ``protocol``."""
        return {
            "grace-period": f"{self.grace_period}s",
            "protocol": protocol,
        }

    # NOTE: the parametrized argument ``signal`` shadows the ``signal`` module
    # inside the test methods; the module is not used there.
    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_graceful_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """An in-flight SSE request keeps streaming after cloudflared is signalled.

        The eyeball thread asserts it received at least ``grace_period * 2``
        lines; that assertion is re-raised here by ``wait_eyeball_thread``.
        """
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            # Blocks until the SSE connection is established (or fails).
            in_flight_req = self._start_eyeball(config, early_terminate=False)
            # Send signal after the SSE connection is established
            self.terminate_by_signal(cloudflared, signal)
            self.wait_eyeball_thread(
                in_flight_req, self.grace_period + self.timeout)

    # test cloudflared terminates before grace period expires when all eyeball
    # connections are drained
    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_shutdown_once_no_connection(self, tmp_path, component_tests_config, signal, protocol):
        """cloudflared exits within the grace period once the eyeball disconnects."""
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())

            # The eyeball closes its connection after two lines (early_terminate).
            in_flight_req = self._start_eyeball(config, early_terminate=True)
            with self.within_grace_period():
                # Send signal after the SSE connection is established
                self.terminate_by_signal(cloudflared, signal)
                self.wait_eyeball_thread(in_flight_req, self.grace_period)

    @pytest.mark.parametrize("signal", supported_signals())
    @pytest.mark.parametrize("protocol", protocols())
    def test_no_connection_shutdown(self, tmp_path, component_tests_config, signal, protocol):
        """With no eyeball connections, cloudflared exits within the grace period."""
        config = component_tests_config(self._extra_config(protocol))
        with start_cloudflared(
                tmp_path, config, new_process=True, capture_output=False) as cloudflared:
            wait_tunnel_ready(tunnel_url=config.get_url())
            with self.within_grace_period():
                self.terminate_by_signal(cloudflared, signal)

    def terminate_by_signal(self, cloudflared, sig):
        """Send ``sig`` to cloudflared, run ``check_tunnel_not_connected()``, then wait for exit."""
        cloudflared.send_signal(sig)
        check_tunnel_not_connected()
        cloudflared.wait()

    def _start_eyeball(self, config, early_terminate):
        """Run ``stream_request`` in a background thread and wait until it connects.

        Returns the started thread. Any exception raised by ``stream_request``
        (including its assertions) is stored on ``thread.errors`` so that it can
        be re-raised in the test's own thread: an exception in a non-main thread
        does not by itself fail the test.

        Raises the eyeball's exception if it failed before connecting, or
        AssertionError if it did not connect within ``self.timeout`` seconds.
        """
        # An Event (unlike a bare Condition.notify) cannot lose the wakeup if the
        # eyeball connects before this thread starts waiting.
        connected = threading.Event()
        errors = []

        def run():
            try:
                self.stream_request(config, connected, early_terminate)
            except Exception as exc:  # re-raised by wait_eyeball_thread / below
                errors.append(exc)
            finally:
                # Unblock the waiter even when the request failed before connecting.
                connected.set()

        thread = threading.Thread(target=run)
        thread.errors = errors
        thread.start()

        connected.wait(self.timeout)
        if errors:
            raise errors[0]
        assert connected.is_set(), "eyeball SSE connection was not established"
        return thread

    def wait_eyeball_thread(self, thread, timeout):
        """Join the eyeball ``thread`` and fail if it is still running or raised.

        ``timeout`` is the maximum number of seconds to wait in ``join``. If the
        thread recorded an exception (see ``_start_eyeball``), it is re-raised.
        """
        thread.join(timeout)
        assert not thread.is_alive(), "eyeball thread is still alive"
        errors = getattr(thread, "errors", None)
        if errors:
            raise errors[0]

    # Using this context asserts logic within the context is executed within grace period
    @contextmanager
    def within_grace_period(self):
        """Assert that the ``with`` body completes in under ``grace_period`` seconds.

        The duration is checked only when the body completes normally, so an
        exception raised inside the block propagates unmasked.
        """
        # Monotonic clock: unaffected by system wall-clock adjustments.
        start = time.monotonic()
        yield
        duration = time.monotonic() - start
        assert duration < self.grace_period, (
            f"took {duration:.2f}s, expected < {self.grace_period}s")

    def stream_request(self, config, connected, early_terminate):
        """Stream the SSE endpoint through the tunnel as an eyeball client.

        ``connected`` must be a ``threading.Event``; it is set once the response
        has been received. Lines are counted until the "502 Bad Gateway"
        terminate message arrives or the stream ends. With ``early_terminate``
        the function returns after two lines, closing the response. Otherwise
        it asserts that at least ``grace_period * 2`` lines were received.
        """
        expected_terminate_message = "502 Bad Gateway".encode()
        url = config.get_url() + self.sse_endpoint

        with requests.get(url, timeout=5, stream=True) as resp:
            connected.set()
            lines = 0
            for line in resp.iter_lines():
                if line == expected_terminate_message:
                    break
                lines += 1
                if early_terminate and lines == 2:
                    return
            # /sse returns count followed by 2 new lines
            assert lines >= (self.grace_period * 2)
|