1
0
Fork 1
mirror of https://github.com/NixOS/nixpkgs.git synced 2025-06-10 18:12:34 +09:00

nixos/test-driver: improve error reporting and assertions (#390996)

This commit is contained in:
Jacek Galowicz 2025-04-26 10:26:01 +02:00 committed by GitHub
commit d0c304d4c1
Signed by: github
GPG key ID: B5690EEEBB952194
9 changed files with 118 additions and 23 deletions

View file

@ -121,8 +121,7 @@ and checks that the output is more-or-less correct:
```py
machine.start()
machine.wait_for_unit("default.target")
if not "Linux" in machine.succeed("uname"):
raise Exception("Wrong OS")
t.assertIn("Linux", machine.succeed("uname"), "Wrong OS")
```
The first line is technically unnecessary; machines are implicitly started
@ -134,6 +133,8 @@ starting them in parallel:
start_all()
```
Under the variable `t`, all assertions from [`unittest.TestCase`](https://docs.python.org/3/library/unittest.html) are available.
If the hostname of a node contains characters that can't be used in a
Python variable name, those characters will be replaced with
underscores in the variable name, so `nodes.machine-a` will be exposed

View file

@ -31,6 +31,7 @@ python3Packages.buildPythonApplication {
colorama
junit-xml
ptpython
ipython
]
++ extraPythonPackages python3Packages;

View file

@ -21,7 +21,7 @@ target-version = "py312"
line-length = 88
lint.select = ["E", "F", "I", "U", "N"]
lint.ignore = ["E501"]
lint.ignore = ["E501", "N818"]
# xxx: we can import https://pypi.org/project/types-colorama/ here
[[tool.mypy.overrides]]

View file

@ -3,7 +3,7 @@ import os
import time
from pathlib import Path
import ptpython.repl
import ptpython.ipython
from test_driver.driver import Driver
from test_driver.logger import (
@ -136,11 +136,10 @@ def main() -> None:
if args.interactive:
history_dir = os.getcwd()
history_path = os.path.join(history_dir, ".nixos-test-history")
ptpython.repl.embed(
driver.test_symbols(),
{},
ptpython.ipython.embed(
user_ns=driver.test_symbols(),
history_filename=history_path,
)
) # type:ignore
else:
tic = time.time()
driver.run_tests()

View file

@ -1,13 +1,17 @@
import os
import re
import signal
import sys
import tempfile
import threading
import traceback
from collections.abc import Callable, Iterator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
from typing import Any
from unittest import TestCase
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from test_driver.machine import Machine, NixStartScript, retry
from test_driver.polling_condition import PollingCondition
@ -16,6 +20,18 @@ from test_driver.vlan import VLan
SENTINEL = object()
class AssertionTester(TestCase):
"""
Subclass of `unittest.TestCase` which is used in the
`testScript` to perform assertions.
It throws a custom exception whose parent class
gets special treatment in the logs.
"""
failureException = RequestedAssertionFailed
def get_tmp_dir() -> Path:
"""Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
Raises an exception in case the retrieved temporary directory is not writeable
@ -115,7 +131,7 @@ class Driver:
try:
yield
except Exception as e:
self.logger.error(f'Test "{name}" failed with error: "{e}"')
self.logger.log_test_error(f'Test "{name}" failed with error: "{e}"')
raise e
def test_symbols(self) -> dict[str, Any]:
@ -140,6 +156,7 @@ class Driver:
serial_stdout_on=self.serial_stdout_on,
polling_condition=self.polling_condition,
Machine=Machine, # for typing
t=AssertionTester(),
)
machine_symbols = {pythonize_name(m.name): m for m in self.machines}
# If there's exactly one machine, make it available under the name
@ -163,7 +180,36 @@ class Driver:
"""Run the test script"""
with self.logger.nested("run the VM test script"):
symbols = self.test_symbols() # call eagerly
try:
exec(self.tests, symbols, None)
except MachineError:
for line in traceback.format_exc().splitlines():
self.logger.log_test_error(line)
sys.exit(1)
except RequestedAssertionFailed:
exc_type, exc, tb = sys.exc_info()
# We manually print the stack frames, keeping only the ones from the test script
# (note: because the script is not a real file, the frame filename is `<string>`)
filtered = [
frame
for frame in traceback.extract_tb(tb)
if frame.filename == "<string>"
]
self.logger.log_test_error("Traceback (most recent call last):")
code = self.tests.splitlines()
for frame, line in zip(filtered, traceback.format_list(filtered)):
self.logger.log_test_error(line.rstrip())
if lineno := frame.lineno:
self.logger.log_test_error(f" {code[lineno - 1].strip()}")
self.logger.log_test_error("") # blank line for readability
exc_prefix = exc_type.__name__ if exc_type is not None else "Error"
for line in f"{exc_prefix}: {exc}".splitlines():
self.logger.log_test_error(line)
sys.exit(1)
def run_tests(self) -> None:
"""Run the test script (for non-interactive test runs)"""

View file

@ -0,0 +1,20 @@
class MachineError(Exception):
"""
Exception that indicates an error that is NOT the user's fault,
i.e. something went wrong without the test being necessarily invalid,
such as failing OCR.
To make it easier to spot, this exception (and its subclasses)
get a `!!!` prefix in the log output.
"""
class RequestedAssertionFailed(AssertionError):
"""
Special assertion that gets thrown on an assertion error,
e.g. a failing `t.assertEqual(...)` or `machine.succeed(...)`.
This gets special treatment in error reporting: i.e. it gets
`!!!` as prefix just as `MachineError`, but only stack frames coming
from `testScript` will show up in logs.
"""

View file

@ -44,6 +44,10 @@ class AbstractLogger(ABC):
def error(self, *args, **kwargs) -> None: # type: ignore
pass
@abstractmethod
def log_test_error(self, *args, **kwargs) -> None: # type:ignore
pass
@abstractmethod
def log_serial(self, message: str, machine: str) -> None:
pass
@ -97,6 +101,9 @@ class JunitXMLLogger(AbstractLogger):
self.tests[self.currentSubtest].stderr += args[0] + os.linesep
self.tests[self.currentSubtest].failure = True
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
self.error(*args, **kwargs)
def log_serial(self, message: str, machine: str) -> None:
if not self._print_serial_logs:
return
@ -156,6 +163,10 @@ class CompositeLogger(AbstractLogger):
for logger in self.logger_list:
logger.warning(*args, **kwargs)
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
logger.log_test_error(*args, **kwargs)
def error(self, *args, **kwargs) -> None: # type: ignore
for logger in self.logger_list:
logger.error(*args, **kwargs)
@ -202,7 +213,7 @@ class TerminalLogger(AbstractLogger):
tic = time.time()
yield
toc = time.time()
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)", attributes)
def info(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
@ -222,6 +233,11 @@ class TerminalLogger(AbstractLogger):
self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
prefix = Fore.RED + "!!! " + Style.RESET_ALL
# NOTE: using `warning` instead of `error` to ensure it does not exit after printing the first log
self.warning(f"{prefix}{args[0]}", *args[1:], **kwargs)
class XMLLogger(AbstractLogger):
def __init__(self, outfile: str) -> None:
@ -261,6 +277,9 @@ class XMLLogger(AbstractLogger):
def error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def log_test_error(self, *args, **kwargs) -> None: # type: ignore
self.log(*args, **kwargs)
def log(self, message: str, attributes: dict[str, str] = {}) -> None:
self.drain_log_queue()
self.log_line(message, attributes)

View file

@ -19,6 +19,7 @@ from pathlib import Path
from queue import Queue
from typing import Any
from test_driver.errors import MachineError, RequestedAssertionFailed
from test_driver.logger import AbstractLogger
from .qmp import QMPSession
@ -129,7 +130,7 @@ def _preprocess_screenshot(screenshot_path: str, negate: bool = False) -> str:
)
if ret.returncode != 0:
raise Exception(
raise MachineError(
f"Image processing failed with exit code {ret.returncode}, stdout: {ret.stdout.decode()}, stderr: {ret.stderr.decode()}"
)
@ -140,7 +141,7 @@ def _perform_ocr_on_screenshot(
screenshot_path: str, model_ids: Iterable[int]
) -> list[str]:
if shutil.which("tesseract") is None:
raise Exception("OCR requested but enableOCR is false")
raise MachineError("OCR requested but enableOCR is false")
processed_image = _preprocess_screenshot(screenshot_path, negate=False)
processed_negative = _preprocess_screenshot(screenshot_path, negate=True)
@ -163,7 +164,7 @@ def _perform_ocr_on_screenshot(
capture_output=True,
)
if ret.returncode != 0:
raise Exception(f"OCR failed with exit code {ret.returncode}")
raise MachineError(f"OCR failed with exit code {ret.returncode}")
model_results.append(ret.stdout.decode("utf-8"))
return model_results
@ -180,7 +181,9 @@ def retry(fn: Callable, timeout: int = 900) -> None:
time.sleep(1)
if not fn(True):
raise Exception(f"action timed out after {timeout} seconds")
raise RequestedAssertionFailed(
f"action timed out after {timeout} tries with one-second pause in-between"
)
class StartCommand:
@ -409,14 +412,14 @@ class Machine:
def check_active(_last_try: bool) -> bool:
state = self.get_unit_property(unit, "ActiveState", user)
if state == "failed":
raise Exception(f'unit "{unit}" reached state "{state}"')
raise RequestedAssertionFailed(f'unit "{unit}" reached state "{state}"')
if state == "inactive":
status, jobs = self.systemctl("list-jobs --full 2>&1", user)
if "No jobs" in jobs:
info = self.get_unit_info(unit, user)
if info["ActiveState"] == state:
raise Exception(
raise RequestedAssertionFailed(
f'unit "{unit}" is inactive and there are no pending jobs'
)
@ -431,7 +434,7 @@ class Machine:
def get_unit_info(self, unit: str, user: str | None = None) -> dict[str, str]:
status, lines = self.systemctl(f'--no-pager show "{unit}"', user)
if status != 0:
raise Exception(
raise RequestedAssertionFailed(
f'retrieving systemctl info for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
@ -461,7 +464,7 @@ class Machine:
user,
)
if status != 0:
raise Exception(
raise RequestedAssertionFailed(
f'retrieving systemctl property "{property}" for unit "{unit}"'
+ ("" if user is None else f' under user "{user}"')
+ f" failed with exit code {status}"
@ -509,7 +512,7 @@ class Machine:
info = self.get_unit_info(unit)
state = info["ActiveState"]
if state != require_state:
raise Exception(
raise RequestedAssertionFailed(
f"Expected unit '{unit}' to to be in state "
f"'{require_state}' but it is in state '{state}'"
)
@ -663,7 +666,9 @@ class Machine:
(status, out) = self.execute(command, timeout=timeout)
if status != 0:
self.log(f"output: {out}")
raise Exception(f"command `{command}` failed (exit code {status})")
raise RequestedAssertionFailed(
f"command `{command}` failed (exit code {status})"
)
output += out
return output
@ -677,7 +682,9 @@ class Machine:
with self.nested(f"must fail: {command}"):
(status, out) = self.execute(command, timeout=timeout)
if status == 0:
raise Exception(f"command `{command}` unexpectedly succeeded")
raise RequestedAssertionFailed(
f"command `{command}` unexpectedly succeeded"
)
output += out
return output
@ -922,7 +929,7 @@ class Machine:
ret = subprocess.run(f"pnmtopng '{tmp}' > '{filename}'", shell=True)
os.unlink(tmp)
if ret.returncode != 0:
raise Exception("Cannot convert screenshot")
raise MachineError("Cannot convert screenshot")
def copy_from_host_via_shell(self, source: str, target: str) -> None:
"""Copy a file from the host into the guest by piping it over the

View file

@ -8,6 +8,7 @@ from test_driver.logger import AbstractLogger
from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
from typing_extensions import Protocol
from pathlib import Path
from unittest import TestCase
class RetryProtocol(Protocol):
@ -51,3 +52,4 @@ join_all: Callable[[], None]
serial_stdout_off: Callable[[], None]
serial_stdout_on: Callable[[], None]
polling_condition: PollingConditionProtocol
t: TestCase