如何在测试中编写和报告断言

使用 assert 语句断言

pytest 允许您使用标准的 Python assert 语句来验证 Python 测试中的预期和值。例如,您可以编写以下代码

# content of test_assert1.py
def f():
    return 3


def test_function():
    assert f() == 4

来断言您的函数返回一个特定值。如果此断言失败,您将看到函数调用的返回值

$ pytest test_assert1.py
=========================== test session starts ============================
platform linux -- Python 3.x.y, pytest-9.x.y, pluggy-1.x.y
rootdir: /home/sweet/project
collected 1 item

test_assert1.py F                                                    [100%]

================================= FAILURES =================================
______________________________ test_function _______________________________

    def test_function():
>       assert f() == 4
E       assert 3 == 4
E        +  where 3 = f()

test_assert1.py:6: AssertionError
========================= short test summary info ==========================
FAILED test_assert1.py::test_function - assert 3 == 4
============================ 1 failed in 0.12s =============================

pytest 支持显示最常见的子表达式的值,包括调用、属性、比较以及二进制和一元运算符。(参见 pytest 的 Python 失败报告演示)。这允许您使用惯用的 Python 结构,无需样板代码,同时不会丢失内省信息。

如果断言中指定了消息,如下所示

assert a % 2 == 0, "value was odd, should be even"

它会与回溯中的断言内省信息一起打印。

有关断言内省的更多信息,请参见 断言内省详情

关于近似相等性的断言

比较浮点值(或浮点数组)时,小的舍入误差很常见。您可以使用 pytest.approx() 来代替使用 assert abs(a - b) < tolnumpy.isclose

import pytest
import numpy as np


def test_floats():
    assert (0.1 + 0.2) == pytest.approx(0.3)


def test_arrays():
    a = np.array([1.0, 2.0, 3.0])
    b = np.array([0.9999, 2.0001, 3.0])
    assert a == pytest.approx(b)

pytest.approx 适用于标量、列表、字典和 NumPy 数组。它还支持涉及 NaN 的比较。

详情请参见 pytest.approx()

关于预期异常的断言

为了编写关于抛出异常的断言,您可以像这样使用 pytest.raises() 作为上下文管理器

import pytest


def test_zero_division():
    with pytest.raises(ZeroDivisionError):
        1 / 0

如果您需要访问实际的异常信息,可以使用

def test_recursion_depth():
    with pytest.raises(RuntimeError) as excinfo:

        def f():
            f()

        f()
    assert "maximum recursion" in str(excinfo.value)

excinfo 是一个 ExceptionInfo 实例,它是围绕实际抛出的异常的包装器。主要关注的属性是 .type.value.traceback

请注意,pytest.raises 将匹配异常类型或其任何子类(就像标准 except 语句一样)。如果您想检查代码块是否抛出确切的异常类型,您需要明确检查

def test_foo_not_implemented():
    def foo():
        raise NotImplementedError

    with pytest.raises(RuntimeError) as excinfo:
        foo()
    assert excinfo.type is RuntimeError

pytest.raises() 调用将成功,即使函数抛出 NotImplementedError,因为 NotImplementedErrorRuntimeError 的子类;但是,下面的 assert 语句将捕获此问题。

匹配异常消息

您可以将 match 关键字参数传递给上下文管理器,以测试正则表达式是否匹配异常的字符串表示(类似于 unittest 中的 TestCase.assertRaisesRegex 方法)

import pytest


def myfunc():
    raise ValueError("Exception 123 raised")


def test_match():
    with pytest.raises(ValueError, match=r".* 123 .*"):
        myfunc()

注意事项

  • match 参数使用 re.search() 函数进行匹配,因此在上面的示例中,match='123' 也会起作用。

  • match 参数也匹配 PEP-678 __notes__

关于预期异常组的断言

当预期 BaseExceptionGroupExceptionGroup 时,您可以使用 pytest.RaisesGroup

def test_exception_in_group():
    with pytest.RaisesGroup(ValueError):
        raise ExceptionGroup("group msg", [ValueError("value msg")])
    with pytest.RaisesGroup(ValueError, TypeError):
        raise ExceptionGroup("msg", [ValueError("foo"), TypeError("bar")])

它接受一个 match 参数,用于检查组消息,以及一个 check 参数,该参数接受一个任意可调用对象,并将组传递给它,只有当可调用对象返回 True 时才成功。

def test_raisesgroup_match_and_check():
    with pytest.RaisesGroup(BaseException, match="my group msg"):
        raise BaseExceptionGroup("my group msg", [KeyboardInterrupt()])
    with pytest.RaisesGroup(
        Exception, check=lambda eg: isinstance(eg.__cause__, ValueError)
    ):
        raise ExceptionGroup("", [TypeError()]) from ValueError()

它对结构和未包装的异常很严格,与 except* 不同,因此您可能需要设置 flatten_subgroups 和/或 allow_unwrapped 参数。

def test_structure():
    with pytest.RaisesGroup(pytest.RaisesGroup(ValueError)):
        raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
    with pytest.RaisesGroup(ValueError, flatten_subgroups=True):
        raise ExceptionGroup("1st group", [ExceptionGroup("2nd group", [ValueError()])])
    with pytest.RaisesGroup(ValueError, allow_unwrapped=True):
        raise ValueError

要指定有关所包含异常的更多详细信息,您可以使用 pytest.RaisesExc

def test_raises_exc():
    with pytest.RaisesGroup(pytest.RaisesExc(ValueError, match="foo")):
        raise ExceptionGroup("", (ValueError("foo")))

它们都提供了一个方法 pytest.RaisesGroup.matches() pytest.RaisesExc.matches(),如果您想在不将其用作上下文管理器的情况下进行匹配,这可能很有用。当检查 .__context__.__cause__ 时,这会很有帮助。

def test_matches():
    exc = ValueError()
    exc_group = ExceptionGroup("", [exc])
    if RaisesGroup(ValueError).matches(exc_group):
        ...
    # helpful error is available in `.fail_reason` if it fails to match
    r = RaisesExc(ValueError)
    assert r.matches(e), r.fail_reason

有关更多详细信息和示例,请查看 pytest.RaisesGrouppytest.RaisesExc 的文档。

ExceptionInfo.group_contains()

警告

此辅助工具可以轻松检查特定异常的存在,但它对于检查组中是否 包含 任何其他异常 非常糟糕。因此,这将通过

class EXTREMELYBADERROR(BaseException):
    """This is a very bad error to miss"""


def test_for_value_error():
    with pytest.raises(ExceptionGroup) as excinfo:
        excs = [ValueError()]
        if very_unlucky():
            excs.append(EXTREMELYBADERROR())
        raise ExceptionGroup("", excs)
    # This passes regardless of if there's other exceptions.
    assert excinfo.group_contains(ValueError)
    # You can't simply list all exceptions you *don't* want to get here.

没有好的方法使用 excinfo.group_contains() 来确保您没有收到除您期望的异常之外的 任何 其他异常。您应该改用 pytest.RaisesGroup,请参见 关于预期异常组的断言

您还可以使用 excinfo.group_contains() 方法来测试作为 ExceptionGroup 一部分返回的异常

def test_exception_in_group():
    with pytest.raises(ExceptionGroup) as excinfo:
        raise ExceptionGroup(
            "Group message",
            [
                RuntimeError("Exception 123 raised"),
            ],
        )
    assert excinfo.group_contains(RuntimeError, match=r".* 123 .*")
    assert not excinfo.group_contains(TypeError)

可选的 match 关键字参数与 pytest.raises() 的工作方式相同。

默认情况下,group_contains() 将递归搜索任何嵌套 ExceptionGroup 实例中任何级别的匹配异常。如果您只想匹配特定级别的异常,可以指定一个 depth 关键字参数;直接包含在顶层 ExceptionGroup 中的异常将匹配 depth=1

def test_exception_in_group_at_given_depth():
    with pytest.raises(ExceptionGroup) as excinfo:
        raise ExceptionGroup(
            "Group message",
            [
                RuntimeError(),
                ExceptionGroup(
                    "Nested group",
                    [
                        TypeError(),
                    ],
                ),
            ],
        )
    assert excinfo.group_contains(RuntimeError, depth=1)
    assert excinfo.group_contains(TypeError, depth=2)
    assert not excinfo.group_contains(RuntimeError, depth=2)
    assert not excinfo.group_contains(TypeError, depth=1)

备用 pytest.raises 形式(旧版)

还有一种 pytest.raises() 的备用形式,您可以通过它传递一个将被执行的函数,以及 *args**kwargspytest.raises() 将使用这些参数执行函数,并断言抛出给定异常

def func(x):
    if x <= 0:
        raise ValueError("x needs to be larger than zero")


pytest.raises(ValueError, func, x=-1)

在失败的情况下,例如 没有异常错误的异常,报告器将为您提供有用的输出。

这种形式是最初的 pytest.raises() API,在 Python 语言中添加 with 语句之前开发。如今,这种形式很少使用,上下文管理器形式(使用 with)被认为更具可读性。尽管如此,这种形式仍得到完全支持,并未以任何方式弃用。

xfail 标记和 pytest.raises

还可以向 pytest.mark.xfail 指定 raises 参数,它检查测试是否以比简单地抛出任何异常更具体的方式失败

def f():
    raise IndexError()


@pytest.mark.xfail(raises=IndexError)
def test_f():
    f()

这只会在测试因抛出 IndexError 或其子类而失败时才会被“xfail”。

  • 使用带有 raises 参数的 pytest.mark.xfail 可能更适合记录未修复的错误(测试描述了“应该”发生的情况)或依赖项中的错误。

  • 使用 pytest.raises() 可能更适合您正在测试自己的代码故意抛出的异常的情况,这是大多数情况。

您还可以使用 pytest.RaisesGroup

def f():
    raise ExceptionGroup("", [IndexError()])


@pytest.mark.xfail(raises=RaisesGroup(IndexError))
def test_f():
    f()

关于预期警告的断言

您可以使用 pytest.warns 检查代码是否抛出特定警告。

利用上下文敏感的比较

pytest 在遇到比较时提供了丰富的上下文敏感信息支持。例如

# content of test_assert2.py
def test_set_comparison():
    set1 = set("1308")
    set2 = set("8035")
    assert set1 == set2

如果您运行此模块

$ pytest test_assert2.py
=========================== test session starts ============================
platform linux -- Python 3.x.y, pytest-9.x.y, pluggy-1.x.y
rootdir: /home/sweet/project
collected 1 item

test_assert2.py F                                                    [100%]

================================= FAILURES =================================
___________________________ test_set_comparison ____________________________

    def test_set_comparison():
        set1 = set("1308")
        set2 = set("8035")
>       assert set1 == set2
E       AssertionError: assert {'0', '1', '3', '8'} == {'0', '3', '5', '8'}
E
E         Extra items in the left set:
E         '1'
E         Extra items in the right set:
E         '5'
E         Use -v to get more diff

test_assert2.py:4: AssertionError
========================= short test summary info ==========================
FAILED test_assert2.py::test_set_comparison - AssertionError: assert {'0'...
============================ 1 failed in 0.12s =============================

对许多情况进行了特殊比较

  • 比较长字符串:显示上下文差异

  • 比较长序列:第一个失败索引

  • 比较字典:不同的条目

有关更多示例,请参见 报告演示

为失败的断言定义自己的解释

可以通过实现 pytest_assertrepr_compare 钩子来添加您自己的详细解释。

pytest_assertrepr_compare(config, op, left, right)[source]

返回失败断言表达式中比较的解释。

如果没有自定义解释,则返回 None,否则返回字符串列表。字符串将由换行符连接,但字符串 的任何换行符都将转义。请注意,除第一行外的所有行都将稍微缩进,目的是让第一行作为摘要。

参数:
  • config (Config) – pytest 配置对象。

  • op (str) – 运算符,例如 "==", "!=", "not in"

  • left (object) – 左操作数。

  • right (object) – 右操作数。

在 conftest 插件中使用

任何 conftest 文件都可以实现此钩子。对于给定的项目,只查阅项目目录及其父目录中的 conftest 文件。

例如,考虑在 conftest.py 文件中添加以下钩子,它为 Foo 对象提供替代解释

# content of conftest.py
from test_foocompare import Foo


def pytest_assertrepr_compare(op, left, right):
    if isinstance(left, Foo) and isinstance(right, Foo) and op == "==":
        return [
            "Comparing Foo instances:",
            f"   vals: {left.val} != {right.val}",
        ]

现在,给定此测试模块

# content of test_foocompare.py
class Foo:
    def __init__(self, val):
        self.val = val

    def __eq__(self, other):
        return self.val == other.val


def test_compare():
    f1 = Foo(1)
    f2 = Foo(2)
    assert f1 == f2

您可以运行测试模块并获取在 conftest 文件中定义的自定义输出

$ pytest -q test_foocompare.py
F                                                                    [100%]
================================= FAILURES =================================
_______________________________ test_compare _______________________________

    def test_compare():
        f1 = Foo(1)
        f2 = Foo(2)
>       assert f1 == f2
E       assert Comparing Foo instances:
E            vals: 1 != 2

test_foocompare.py:12: AssertionError
========================= short test summary info ==========================
FAILED test_foocompare.py::test_compare - assert Comparing Foo instances:
1 failed in 0.12s

在测试函数中返回非 None 值

当测试函数返回非 None 值时,会发出 pytest.PytestReturnNotNoneWarning

这有助于防止初学者常犯的一个错误,他们认为返回布尔值(例如 TrueFalse)将决定测试是否通过或失败。

示例

@pytest.mark.parametrize(
    ["a", "b", "result"],
    [
        [1, 2, 5],
        [2, 3, 8],
        [5, 3, 18],
    ],
)
def test_foo(a, b, result):
    return foo(a, b) == result  # Incorrect usage, do not do this.

由于 pytest 忽略返回值,因此测试永远不会根据返回值失败,这可能会令人惊讶。

正确的修复方法是将 return 语句替换为 assert

@pytest.mark.parametrize(
    ["a", "b", "result"],
    [
        [1, 2, 5],
        [2, 3, 8],
        [5, 3, 18],
    ],
)
def test_foo(a, b, result):
    assert foo(a, b) == result

断言内省详情

通过在断言语句运行之前重写它们来实现报告失败断言的详细信息。重写的断言语句将内省信息放入断言失败消息中。pytest 只重写其测试收集过程直接发现的测试模块,因此 支持模块中本身不是测试模块的断言将不会被重写

您可以通过在导入模块之前调用 register_assert_rewrite 来手动为导入的模块启用断言重写(一个好的地方是在您的根 conftest.py 中这样做)。

有关更多信息,Benjamin Peterson 撰写了 pytest 新断言重写幕后

断言重写将文件缓存到磁盘

pytest 会将重写后的模块写回磁盘进行缓存。您可以通过在 conftest.py 文件顶部添加以下内容来禁用此行为(例如,为了避免在经常移动文件的项目中留下过时的 .pyc 文件)

import sys

sys.dont_write_bytecode = True

请注意,您仍然可以获得断言内省的好处,唯一的改变是 .pyc 文件不会缓存到磁盘。

此外,如果无法写入新的 .pyc 文件,例如在只读文件系统或 zipfile 中,重写将静默跳过缓存。

禁用断言重写

pytest 在导入时通过使用导入钩子来重写测试模块,以写入新的 pyc 文件。大多数情况下,这都是透明地工作。但是,如果您自己处理导入机制,导入钩子可能会干扰。

如果是这种情况,您有两个选择

  • 通过在其文档字符串中添加字符串 PYTEST_DONT_REWRITE 来禁用特定模块的重写。

  • 通过使用 --assert=plain 禁用所有模块的重写。