Saul Shanabrook

Partially Applied Pytest Fixtures

2020-10-07

TLDR: If you want pytest fixtures that are themselves functions, you can partially curry them, as long as you preserve the signatures and names properly.

In this small example, the custom partial_curry function lets a pytest fixture return a function quite nicely: wrap the function in the partial_curry("tmp_path") decorator first, then apply pytest.fixture on top. Note that partial_curry is defined at the end of the listing; in a real module it has to be defined before the fixture that uses it, since decorators run at definition time:

from __future__ import annotations

import functools
import inspect
import typing

import pytest


@pytest.fixture
@partial_curry("tmp_path")
def create_tmp_file(tmp_path, content: str):
    """
    Creates a temporary file with some text in it and returns its pathlib.Path
    """
    # tmp_path is pytest's built-in fixture; it is already a pathlib.Path, so it is not called
    p = tmp_path / "file.txt"
    p.write_text(content)
    return p

def test_tmp_file(create_tmp_file):
    """
    Pulls in the create_tmp_file fixture and calls it with some args
    """
    f = create_tmp_file("some content")
    assert f.read_text() == "some content"


# TODO: improve typing in python 3.9 https://www.python.org/dev/peps/pep-0593/
def partial_curry(
    *first_args: str,
) -> typing.Callable[[typing.Callable], typing.Callable]:
    """
    Partially curry a function: split its parameters into two groups, so that the named
    ones are passed in a first call and the remaining ones in a second call.

    Updates the signature and names properly.

    >>> @partial_curry("hi", "there")
    ... def hello_world(hi, there, other):
    ...     return hi + there + other
    >>> hello_world(hi="h", there="t")(other="r")
    'htr'
    >>> hello_world("h", "t")("r")
    'htr'
    >>> hello_world(hi="h", there="t")("r")
    'htr'
    >>> import inspect
    >>> inspect.signature(hello_world)
    <Signature (hi, there)>
    >>> inspect.signature(hello_world(1, 2))
    <Signature (other)>
    """

    def wrapper(f):
        s = inspect.signature(f)
        outer_s = s.replace(
            parameters=[p for p in s.parameters.values() if p.name in first_args]
        )
        inner_s = s.replace(
            parameters=[p for p in s.parameters.values() if p.name not in first_args]
        )

        @functools.wraps(f)
        def outer(*outer_args, **outer_kwargs):
            # Bind each so args/kwargs are all resolved properly and normalized
            outer_bound = outer_s.bind(*outer_args, **outer_kwargs)

            @functools.wraps(f)
            def inner(*inner_args, **inner_kwargs):
                inner_bound = inner_s.bind(*inner_args, **inner_kwargs)
                return f(
                    *outer_bound.args,
                    *inner_bound.args,
                    **outer_bound.kwargs,
                    **inner_bound.kwargs,
                )

            inner.__signature__ = inner_s  # type: ignore
            return inner

        outer.__signature__ = outer_s  # type: ignore
        return outer

    return wrapper
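
The two `__signature__` assignments are what make this work as a fixture. As far as I understand it, pytest decides which fixtures to inject by inspecting the parameter names of the fixture function, and `functools.wraps` on its own points `inspect.signature` back at the original function with all of its parameters. A minimal sketch of that difference, using hypothetical `original` and `wrapped` functions that are not part of the code above:

```python
import functools
import inspect


def original(tmp_path, content):
    return (tmp_path, content)


@functools.wraps(original)
def wrapped(*args, **kwargs):
    return original(*args, **kwargs)


# With only functools.wraps, inspect follows __wrapped__ back to the original,
# so both parameters are reported.
before = list(inspect.signature(wrapped).parameters)

# Setting __signature__ overrides what inspect reports, which is how
# partial_curry exposes only the first group of parameters.
full = inspect.signature(original)
wrapped.__signature__ = full.replace(parameters=[full.parameters["tmp_path"]])
after = list(inspect.signature(wrapped).parameters)

print(before)  # ['tmp_path', 'content']
print(after)   # ['tmp_path']
```

Without the override, a fixture built this way would appear to ask for a `content` fixture as well, which does not exist.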

Background

I am currently experimenting with a GraphQL backend for the Jupyter server and I wanted to share a helpful technique that I found for testing it.

I am using pytest fixtures to provide the GraphQL schema, and I want to be able to run queries against it easily through another fixture. At first, I had a fixture that returned an instance of a callable class, like this (raise_errors_directly is an error formatter defined elsewhere and not shown here):

import dataclasses
import ariadne
import graphql
import pytest


@dataclasses.dataclass
class QueryCaller:
    schema: graphql.type.schema.GraphQLSchema

    async def __call__(self, query: str, expect_error=False, **variables):
        """
        If expect_error is true, verifies that the result contains errors and returns the data.

        Otherwise, raises a real error.
        """
        success, result = await ariadne.graphql(
            self.schema,
            {"query": query, "variables": variables},
            debug=True,
            error_formatter=raise_errors_directly
            if not expect_error
            else ariadne.format_error,
        )
        assert success
        if expect_error:
            assert result["errors"]
        return result["data"]

@pytest.fixture
def query(schema):
    return QueryCaller(schema)

async def test_kernels_mutations(query):
    assert (
        await query(
            """
            query {
                kernels {
                    id
                }
            }
            """
        )
        == {"kernels": []}
    )

However, it's a bit troublesome to have to write this custom class and return an instance of it from the fixture. Instead, using the partial_curry function above, you can write a single function that pulls some of its arguments from pytest fixtures and takes the rest from the test that calls it:

@pytest.fixture
@partial_curry("schema")
async def query(
    schema: graphql.type.schema.GraphQLSchema,
    query: str,
    expect_error=False,
    **variables,
):
    """
    If expect_error is true, verifies that the result contains errors and returns the data.

    Otherwise, raises a real error.
    """
    success, result = await ariadne.graphql(
        schema,
        {"query": query, "variables": variables},
        debug=True,
        error_formatter=raise_errors_directly
        if not expect_error
        else ariadne.format_error,
    )
    assert success
    if expect_error:
        assert result["errors"]
    return result["data"]
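
One detail worth spelling out: although query is an async def, the fixture itself stays synchronous. It hands the test a plain function that returns a coroutine when called, so the test still writes `await query(...)`. Here is a minimal sketch of that behavior. It assumes a hypothetical `run_query` stand-in that looks results up in a dict instead of calling ariadne.graphql, and it uses a condensed copy of partial_curry so the block runs on its own:

```python
import asyncio
import functools
import inspect


def partial_curry(*first_args):
    # Condensed copy of partial_curry from the top of the post.
    def wrapper(f):
        s = inspect.signature(f)
        params = list(s.parameters.values())
        outer_s = s.replace(parameters=[p for p in params if p.name in first_args])
        inner_s = s.replace(parameters=[p for p in params if p.name not in first_args])

        @functools.wraps(f)
        def outer(*outer_args, **outer_kwargs):
            outer_bound = outer_s.bind(*outer_args, **outer_kwargs)

            @functools.wraps(f)
            def inner(*inner_args, **inner_kwargs):
                inner_bound = inner_s.bind(*inner_args, **inner_kwargs)
                return f(
                    *outer_bound.args,
                    *inner_bound.args,
                    **outer_bound.kwargs,
                    **inner_bound.kwargs,
                )

            inner.__signature__ = inner_s
            return inner

        outer.__signature__ = outer_s
        return outer

    return wrapper


@partial_curry("schema")
async def run_query(schema, query, expect_error=False, **variables):
    # Hypothetical stand-in for the real GraphQL call: the "schema" is just a dict.
    await asyncio.sleep(0)
    return {"data": schema[query], "variables": variables}


# First call: what pytest would do when it injects the schema fixture.
bound = run_query({"kernels": []})

# Second call: what the test does; the returned coroutine is awaited as usual.
result = asyncio.run(bound("kernels", id=1))
print(result)  # {'data': [], 'variables': {'id': 1}}
```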

Interestingly enough, you can get almost the same runtime semantics as partial_curry from functools.partial(functools.partial, functools.partial). The signature won't be updated, though, because this version is generic: it accepts any number of arguments in the first call and any number in the second:

import functools

@functools.partial(functools.partial, functools.partial)
def hello_there(a, b):
    return a + b

hello_there(a=1)(b=2)
hello_there()(a=1, b=2)
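
To make the difference concrete, here is a small sketch that checks both claims: any split of the arguments between the two calls works, and the reported signature is not narrowed to a first group of parameters. The `parameter_names` helper is my own addition for illustration. Depending on the Python version, `inspect.signature` may raise for this object or report generic parameters, so the helper tolerates both:

```python
import functools
import inspect


@functools.partial(functools.partial, functools.partial)
def hello_there(a, b):
    return a + b


# Any split of the arguments between the first and second call is accepted.
results = [
    hello_there(1)(2),
    hello_there()(1, 2),
    hello_there(1, 2)(),
    hello_there(a=1)(b=2),
]


def parameter_names(func):
    """Return the parameter names inspect reports, or None if it cannot tell."""
    try:
        return list(inspect.signature(func).parameters)
    except (ValueError, TypeError):
        return None


# Whatever this is on a given Python version, it is not ["a"], so pytest
# would have no way to know that "a" should come from a fixture.
names = parameter_names(hello_there)
```

That is the reason partial_curry goes to the trouble of building the two signatures explicitly.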