diff --git a/nixos/doc/manual/development/writing-nixos-tests.section.md b/nixos/doc/manual/development/writing-nixos-tests.section.md index bd588e2ba80b..5b08975e5ea4 100644 --- a/nixos/doc/manual/development/writing-nixos-tests.section.md +++ b/nixos/doc/manual/development/writing-nixos-tests.section.md @@ -121,8 +121,7 @@ and checks that the output is more-or-less correct: ```py machine.start() machine.wait_for_unit("default.target") -if not "Linux" in machine.succeed("uname"): - raise Exception("Wrong OS") +t.assertIn("Linux", machine.succeed("uname"), "Wrong OS") ``` The first line is technically unnecessary; machines are implicitly started @@ -134,6 +133,8 @@ starting them in parallel: start_all() ``` +Under the variable `t`, all assertions from [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html) are available. + If the hostname of a node contains characters that can't be used in a Python variable name, those characters will be replaced with underscores in the variable name, so `nodes.machine-a` will be exposed diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix index f22744806d48..91db5d8be3c2 100644 --- a/nixos/lib/test-driver/default.nix +++ b/nixos/lib/test-driver/default.nix @@ -31,6 +31,7 @@ python3Packages.buildPythonApplication { colorama junit-xml ptpython + ipython ] ++ extraPythonPackages python3Packages; diff --git a/nixos/lib/test-driver/src/pyproject.toml b/nixos/lib/test-driver/src/pyproject.toml index ac83eed268d9..fa4e6a2de127 100644 --- a/nixos/lib/test-driver/src/pyproject.toml +++ b/nixos/lib/test-driver/src/pyproject.toml @@ -21,7 +21,7 @@ target-version = "py312" line-length = 88 lint.select = ["E", "F", "I", "U", "N"] -lint.ignore = ["E501"] +lint.ignore = ["E501", "N818"] # xxx: we can import https://pypi.org/project/types-colorama/ here [[tool.mypy.overrides]] diff --git a/nixos/lib/test-driver/src/test_driver/__init__.py b/nixos/lib/test-driver/src/test_driver/__init__.py index 1c0793aa75a5..26d8391017e8 100755 --- a/nixos/lib/test-driver/src/test_driver/__init__.py +++ b/nixos/lib/test-driver/src/test_driver/__init__.py @@ -3,7 +3,7 @@ import os import time from pathlib import Path -import ptpython.repl +import ptpython.ipython from test_driver.driver import Driver from test_driver.logger import ( @@ -136,11 +136,10 @@ def main() -> None: if args.interactive: history_dir = os.getcwd() history_path = os.path.join(history_dir, ".nixos-test-history") - ptpython.repl.embed( - driver.test_symbols(), - {}, + ptpython.ipython.embed( + user_ns=driver.test_symbols(), history_filename=history_path, - ) + ) # type:ignore else: tic = time.time() driver.run_tests() diff --git a/nixos/lib/test-driver/src/test_driver/driver.py b/nixos/lib/test-driver/src/test_driver/driver.py index 6061c1bc09b8..49b6692bf422 100644 --- a/nixos/lib/test-driver/src/test_driver/driver.py +++ b/nixos/lib/test-driver/src/test_driver/driver.py @@ -1,13 +1,17 @@ import os import re import signal +import sys import tempfile import threading +import traceback from collections.abc import Callable, Iterator from contextlib import AbstractContextManager, contextmanager from pathlib import Path from typing import Any +from unittest import TestCase +from test_driver.errors import MachineError, RequestedAssertionFailed from test_driver.logger import AbstractLogger from test_driver.machine import Machine, NixStartScript, retry from test_driver.polling_condition import PollingCondition @@ -16,6 +20,18 @@ from test_driver.vlan import VLan SENTINEL = object() +class AssertionTester(TestCase): + """ + Subclass of `unittest.TestCase` which is used in the + `testScript` to perform assertions. + + It throws a custom exception whose parent class + gets special treatment in the logs. + """ + + failureException = RequestedAssertionFailed + + def get_tmp_dir() -> Path: """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD Raises an exception in case the retrieved temporary directory is not writeable @@ -115,7 +131,7 @@ class Driver: try: yield except Exception as e: - self.logger.error(f'Test "{name}" failed with error: "{e}"') + self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"') raise e def test_symbols(self) -> dict[str, Any]: @@ -140,6 +156,7 @@ class Driver: serial_stdout_on=self.serial_stdout_on, polling_condition=self.polling_condition, Machine=Machine, # for typing + t=AssertionTester(), ) machine_symbols = {pythonize_name(m.name): m for m in self.machines} # If there's exactly one machine, make it available under the name @@ -163,7 +180,36 @@ class Driver: """Run the test script""" with self.logger.nested("run the VM test script"): symbols = self.test_symbols() # call eagerly - exec(self.tests, symbols, None) + try: + exec(self.tests, symbols, None) + except MachineError: + for line in traceback.format_exc().splitlines(): + self.logger.log_test_error(line) + sys.exit(1) + except RequestedAssertionFailed: + exc_type, exc, tb = sys.exc_info() + # We manually print the stack frames, keeping only the ones from the test script + # (note: because the script is not a real file, the frame filename is ``) + filtered = [ + frame + for frame in traceback.extract_tb(tb) + if frame.filename == "" + ] + + self.logger.log_test_error("Traceback (most recent call last):") + + code = self.tests.splitlines() + for frame, line in zip(filtered, traceback.format_list(filtered)): + self.logger.log_test_error(line.rstrip()) + if lineno := frame.lineno: + self.logger.log_test_error(f" {code[lineno - 1].strip()}") + + self.logger.log_test_error("") # blank line for readability + exc_prefix = exc_type.__name__ if exc_type is not None else "Error" + for line in f"{exc_prefix}: {exc}".splitlines(): + self.logger.log_test_error(line) + + sys.exit(1) def run_tests(self) -> None: """Run the test script (for non-interactive test runs)""" diff --git a/nixos/lib/test-driver/src/test_driver/errors.py b/nixos/lib/test-driver/src/test_driver/errors.py new file mode 100644 index 000000000000..fe072b5185c9 --- /dev/null +++ b/nixos/lib/test-driver/src/test_driver/errors.py @@ -0,0 +1,20 @@ +class MachineError(Exception): + """ + Exception that indicates an error that is NOT the user's fault, + i.e. something went wrong without the test being necessarily invalid, + such as failing OCR. + + To make it easier to spot, this exception (and its subclasses) + get a `!!!` prefix in the log output. + """ + + +class RequestedAssertionFailed(AssertionError): + """ + Special assertion that gets thrown on an assertion error, + e.g. a failing `t.assertEqual(...)` or `machine.succeed(...)`. + + This gets special treatment in error reporting: i.e. it gets + `!!!` as prefix just as `MachineError`, but only stack frames coming + from `testScript` will show up in logs. + """ diff --git a/nixos/lib/test-driver/src/test_driver/logger.py b/nixos/lib/test-driver/src/test_driver/logger.py index 564d39f4f055..a218d234fe3f 100644 --- a/nixos/lib/test-driver/src/test_driver/logger.py +++ b/nixos/lib/test-driver/src/test_driver/logger.py @@ -44,6 +44,10 @@ class AbstractLogger(ABC): def error(self, *args, **kwargs) -> None: # type: ignore pass + @abstractmethod + def log_test_error(self, *args, **kwargs) -> None: # type:ignore + pass + @abstractmethod def log_serial(self, message: str, machine: str) -> None: pass @@ -97,6 +101,9 @@ class JunitXMLLogger(AbstractLogger): self.tests[self.currentSubtest].stderr += args[0] + os.linesep self.tests[self.currentSubtest].failure = True + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + self.error(*args, **kwargs) + def log_serial(self, message: str, machine: str) -> None: if not self._print_serial_logs: return @@ -156,6 +163,10 @@ class CompositeLogger(AbstractLogger): for logger in self.logger_list: logger.warning(*args, **kwargs) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.log_test_error(*args, **kwargs) + def error(self, *args, **kwargs) -> None: # type: ignore for logger in self.logger_list: logger.error(*args, **kwargs) @@ -202,7 +213,7 @@ class TerminalLogger(AbstractLogger): tic = time.time() yield toc = time.time() - self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") + self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)", attributes) def info(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) @@ -222,6 +233,11 @@ class TerminalLogger(AbstractLogger): self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + prefix = Fore.RED + "!!! " + Style.RESET_ALL + # NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log + self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs) + class XMLLogger(AbstractLogger): def __init__(self, outfile: str) -> None: @@ -261,6 +277,9 @@ class XMLLogger(AbstractLogger): def error(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) + def log_test_error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + def log(self, message: str, attributes: dict[str, str] = {}) -> None: self.drain_log_queue() self.log_line(message, attributes) diff --git a/nixos/lib/test-driver/src/test_driver/machine.py b/nixos/lib/test-driver/src/test_driver/machine.py index cba386ae86b4..1b9dd1262ce6 100644 --- a/nixos/lib/test-driver/src/test_driver/machine.py +++ b/nixos/lib/test-driver/src/test_driver/machine.py @@ -19,6 +19,7 @@ from pathlib import Path from queue import Queue from typing import Any +from test_driver.errors import MachineError, RequestedAssertionFailed from test_driver.logger import AbstractLogger from .qmp import QMPSession @@ -129,7 +130,7 @@ def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str: ) if ret.returncode != 0: - raise Exception( + raise MachineError( f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}" ) @@ -140,7 +141,7 @@ def _perform_ocr_on_screenshot( screenshot_path: str, model_ids: Iterable[int] ) -> list[str]: if shutil.which("tesseract") is None: - raise Exception("OCR requested but enableOCR is false") + raise MachineError("OCR requested but enableOCR is false") processed_image = _preprocess_screenshot(screenshot_path, negate=False) processed_negative = _preprocess_screenshot(screenshot_path, negate=True) @@ -163,7 +164,7 @@ def _perform_ocr_on_screenshot( capture_output=True, ) if ret.returncode != 0: - raise Exception(f"OCR failed with exit code {ret.returncode}") + raise MachineError(f"OCR failed with exit code {ret.returncode}") model_results.append(ret.stdout.decode("utf-8")) return model_results @@ -180,7 +181,9 @@ def retry(fn: Callable, timeout: int = 900) -> None: time.sleep(1) if not fn(True): - raise Exception(f"action timed out after {timeout} seconds") + raise RequestedAssertionFailed( + f"action timed out after {timeout} tries with one-second pause in-between" + ) class StartCommand: @@ -409,14 +412,14 @@ class Machine: def check_active(_last_try: bool) -> bool: state = self.get_unit_property(unit, "ActiveState", user) if state == "failed": - raise Exception(f'unit "{unit}" reached state "{state}"') + raise RequestedAssertionFailed(f'unit "{unit}" reached state "{state}"') if state == "inactive": status, jobs = self.systemctl("list-jobs --full 2>&1", user) if "No jobs" in jobs: info = self.get_unit_info(unit, user) if info["ActiveState"] == state: - raise Exception( + raise RequestedAssertionFailed( f'unit "{unit}" is inactive and there are no pending jobs' ) @@ -431,7 +434,7 @@ class Machine: def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]: status, lines = self.systemctl(f'--no-pager show "{unit}"', user) if status != 0: - raise Exception( + raise RequestedAssertionFailed( f'retrieving systemctl info for unit "{unit}"' + ("" if user is None else f' under user "{user}"') + f" failed with exit code {status}" @@ -461,7 +464,7 @@ class Machine: user, ) if status != 0: - raise Exception( + raise RequestedAssertionFailed( f'retrieving systemctl property "{property}" for unit "{unit}"' + ("" if user is None else f' under user "{user}"') + f" failed with exit code {status}" @@ -509,7 +512,7 @@ class Machine: info = self.get_unit_info(unit) state = info["ActiveState"] if state != require_state: - raise Exception( + raise RequestedAssertionFailed( f"Expected unit '{unit}' to to be in state " f"'{require_state}' but it is in state '{state}'" ) @@ -663,7 +666,9 @@ class Machine: (status, out) = self.execute(command, timeout=timeout) if status != 0: self.log(f"output: {out}") - raise Exception(f"command `{command}` failed (exit code {status})") + raise RequestedAssertionFailed( + f"command `{command}` failed (exit code {status})" + ) output += out return output @@ -677,7 +682,9 @@ class Machine: with self.nested(f"must fail: {command}"): (status, out) = self.execute(command, timeout=timeout) if status == 0: - raise Exception(f"command `{command}` unexpectedly succeeded") + raise RequestedAssertionFailed( + f"command `{command}` unexpectedly succeeded" + ) output += out return output @@ -922,7 +929,7 @@ class Machine: ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True) os.unlink(tmp) if ret.returncode != 0: - raise Exception("Cannot convert screenshot") + raise MachineError("Cannot convert screenshot") def copy_from_host_via_shell(self, source: str, target: str) -> None: """Copy a file from the host into the guest by piping it over the diff --git a/nixos/lib/test-script-prepend.py b/nixos/lib/test-script-prepend.py index 9d2efdf97303..31dad14ef8dd 100644 --- a/nixos/lib/test-script-prepend.py +++ b/nixos/lib/test-script-prepend.py @@ -8,6 +8,7 @@ from test_driver.logger import AbstractLogger from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union from typing_extensions import Protocol from pathlib import Path +from unittest import TestCase class RetryProtocol(Protocol): @@ -51,3 +52,4 @@ join_all: Callable[[], None] serial_stdout_off: Callable[[], None] serial_stdout_on: Callable[[], None] polling_condition: PollingConditionProtocol +t: TestCase