Skip to content

Commit a4c0028

Browse files
committed
Fixed middleware testing.
1 parent f914db0 commit a4c0028

File tree

11 files changed

+154
-98
lines changed

11 files changed

+154
-98
lines changed

taskiq/decor.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
import sys
23
from collections.abc import Coroutine
34
from datetime import datetime
@@ -8,13 +9,15 @@
89
Callable,
910
Dict,
1011
Generic,
12+
Optional,
1113
TypeVar,
1214
Union,
1315
overload,
1416
)
1517

1618
from typing_extensions import ParamSpec
1719

20+
from taskiq.abc.middleware import TaskiqMiddleware
1821
from taskiq.kicker import AsyncKicker
1922
from taskiq.scheduler.created_schedule import CreatedSchedule
2023
from taskiq.task import AsyncTaskiqTask
@@ -51,11 +54,15 @@ def __init__(
5154
task_name: str,
5255
original_func: Callable[_FuncParams, _ReturnType],
5356
labels: Dict[str, Any],
57+
extra_middlewares: Optional[list[TaskiqMiddleware]] = None,
5458
) -> None:
5559
self.broker = broker
5660
self.task_name = task_name
5761
self.original_func = original_func
5862
self.labels = labels
63+
self.middlewares = copy.copy(broker.middlewares)
64+
if extra_middlewares:
65+
self.middlewares.extend(extra_middlewares)
5966

6067
# This is a hack to make ProcessPoolExecutor work
6168
# with decorated functions.

taskiq/kicker.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def kiq(
153153
logger.debug(
154154
f"Kicking {self.task_name} with args={args} and kwargs={kwargs}.",
155155
)
156-
message = self._prepare_message(*args, **kwargs)
156+
message = self.get_message(*args, **kwargs)
157157
for middleware in self.broker.middlewares:
158158
if middleware.__class__.pre_send != TaskiqMiddleware.pre_send:
159159
message = await maybe_awaitable(middleware.pre_send(message))
@@ -191,7 +191,7 @@ async def schedule_by_cron(
191191
schedule_id = self.custom_schedule_id
192192
if schedule_id is None:
193193
schedule_id = self.broker.id_generator()
194-
message = self._prepare_message(*args, **kwargs)
194+
message = self.get_message(*args, **kwargs)
195195
cron_offset = None
196196
if isinstance(cron, CronSpec):
197197
cron_str = cron.to_cron()
@@ -228,7 +228,7 @@ async def schedule_by_time(
228228
schedule_id = self.custom_schedule_id
229229
if schedule_id is None:
230230
schedule_id = self.broker.id_generator()
231-
message = self._prepare_message(*args, **kwargs)
231+
message = self.get_message(*args, **kwargs)
232232
scheduled = ScheduledTask(
233233
schedule_id=schedule_id,
234234
task_name=message.task_name,
@@ -261,10 +261,10 @@ def _prepare_arg(cls, arg: Any) -> Any:
261261
arg = asdict(arg)
262262
return arg
263263

264-
def _prepare_message(
264+
def get_message(
265265
self,
266-
*args: Any,
267-
**kwargs: Any,
266+
*args: _FuncParams.args,
267+
**kwargs: _FuncParams.kwargs,
268268
) -> TaskiqMessage:
269269
"""
270270
Create a message from args and kwargs.

taskiq/receiver/receiver.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ async def callback( # noqa: C901, PLR0912
125125
"Function for task %s is resolved. Executing...",
126126
taskiq_msg.task_name,
127127
)
128-
for middleware in self.broker.middlewares:
128+
for middleware in task.middlewares:
129129
if middleware.__class__.pre_execute != TaskiqMiddleware.pre_execute:
130130
taskiq_msg = await maybe_awaitable(
131131
middleware.pre_execute(
@@ -150,21 +150,32 @@ async def callback( # noqa: C901, PLR0912
150150
message=taskiq_msg,
151151
)
152152

153+
if result.is_err is not None:
154+
for middleware in task.middlewares:
155+
if middleware.__class__.on_error != TaskiqMiddleware.on_error:
156+
await maybe_awaitable(
157+
middleware.on_error(
158+
taskiq_msg,
159+
result,
160+
result.error, # type: ignore
161+
),
162+
)
163+
153164
if self.ack_time == AcknowledgeType.WHEN_EXECUTED and isinstance(
154165
message,
155166
AckableMessage,
156167
):
157168
await maybe_awaitable(message.ack())
158169

159-
for middleware in self.broker.middlewares:
170+
for middleware in task.middlewares:
160171
if middleware.__class__.post_execute != TaskiqMiddleware.post_execute:
161172
await maybe_awaitable(middleware.post_execute(taskiq_msg, result))
162173

163174
try:
164175
if not isinstance(result.error, NoResultError):
165176
await self.broker.result_backend.set_result(taskiq_msg.task_id, result)
166177

167-
for middleware in self.broker.middlewares:
178+
for middleware in task.middlewares:
168179
if middleware.__class__.post_save != TaskiqMiddleware.post_save:
169180
await maybe_awaitable(middleware.post_save(taskiq_msg, result))
170181

@@ -183,7 +194,7 @@ async def callback( # noqa: C901, PLR0912
183194
):
184195
await maybe_awaitable(message.ack())
185196

186-
async def run_task( # noqa: C901, PLR0912, PLR0915
197+
async def run_task( # noqa: C901
187198
self,
188199
target: Callable[..., Any],
189200
message: TaskiqMessage,
@@ -304,17 +315,6 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
304315
error=found_exception,
305316
labels=message.labels,
306317
)
307-
# If exception is found we execute middlewares.
308-
if found_exception is not None:
309-
for middleware in self.broker.middlewares:
310-
if middleware.__class__.on_error != TaskiqMiddleware.on_error:
311-
await maybe_awaitable(
312-
middleware.on_error(
313-
message,
314-
result,
315-
found_exception,
316-
),
317-
)
318318

319319
return result
320320

taskiq/task_creator.py

+21-9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from typing_extensions import ParamSpec, Self
1717

18+
from taskiq.abc.middleware import TaskiqMiddleware
1819
from taskiq.decor import AsyncTaskiqDecoratedTask
1920
from taskiq.utils import remove_suffix
2021

@@ -37,6 +38,7 @@ def __init__(self, broker: "AsyncBroker") -> None:
3738
self._broker = broker
3839
self._task_name: Optional[str] = None
3940
self._labels: Dict[str, Any] = {}
41+
self._middlewares: list[TaskiqMiddleware] = []
4042

4143
def name(self, name: str) -> Self:
4244
"""Assign custom name to the task."""
@@ -48,6 +50,11 @@ def labels(self, **labels: Any) -> Self:
4850
self._labels = labels
4951
return self
5052

53+
def middlewares(self, *middlewares: TaskiqMiddleware) -> Self:
54+
"""Assign custom middlewares to the task."""
55+
self._middlewares = list(middlewares)
56+
return self
57+
5158
def make_task(
5259
self,
5360
task_name: str,
@@ -59,8 +66,21 @@ def make_task(
5966
original_func=func,
6067
labels=self._labels,
6168
task_name=task_name,
69+
extra_middlewares=self._middlewares,
6270
)
6371

72+
def __resolve_name(self, func: Callable[..., Any]) -> str:
73+
"""Resolve name of the function."""
74+
fmodule = func.__module__
75+
if fmodule == "__main__": # pragma: no cover
76+
fmodule = ".".join(
77+
remove_suffix(sys.argv[0], ".py").split(os.path.sep),
78+
)
79+
fname = func.__name__
80+
if fname == "<lambda>":
81+
fname = f"lambda_{uuid4().hex}"
82+
return f"{fmodule}:{fname}"
83+
6484
@overload
6585
def __call__(
6686
self,
@@ -131,15 +151,7 @@ def inner(
131151
) -> AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType]:
132152
inner_task_name = self._task_name
133153
if inner_task_name is None:
134-
fmodule = func.__module__
135-
if fmodule == "__main__": # pragma: no cover
136-
fmodule = ".".join(
137-
remove_suffix(sys.argv[0], ".py").split(os.path.sep),
138-
)
139-
fname = func.__name__
140-
if fname == "<lambda>":
141-
fname = f"lambda_{uuid4().hex}"
142-
inner_task_name = f"{fmodule}:{fname}"
154+
inner_task_name = self.__resolve_name(func)
143155

144156
wrapper = wraps(func)
145157
decorated_task = wrapper(

tests/abc/test_broker.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_decorator_with_name_success() -> None:
4040
"""Test that task_name is successfully set."""
4141
tbrok = _TestBroker()
4242

43-
@tbrok.task(task_name="my_task")
43+
@tbrok.task.name("my_task")
4444
async def test_func() -> None:
4545
"""Some test function."""
4646

@@ -52,7 +52,7 @@ def test_decorator_with_labels_success() -> None:
5252
"""Tests that labels are assigned for task as is."""
5353
tbrok = _TestBroker()
5454

55-
@tbrok.task(label1=1, label2=2)
55+
@tbrok.task.labels(label1=1, label2=2)
5656
async def test_func() -> None:
5757
"""Some test function."""
5858

tests/api/test_scheduler.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import contextlib
3-
from datetime import datetime, timedelta
3+
from datetime import datetime, timedelta, timezone
44

55
import pytest
66

@@ -16,7 +16,9 @@ async def test_successful() -> None:
1616
scheduler = TaskiqScheduler(broker, sources=[LabelScheduleSource(broker)])
1717
scheduler_task = asyncio.create_task(run_scheduler_task(scheduler))
1818

19-
@broker.task(schedule=[{"time": datetime.utcnow() - timedelta(seconds=1)}])
19+
@broker.task.labels(
20+
schedule=[{"time": datetime.now(timezone.utc) - timedelta(seconds=1)}],
21+
)
2022
def _() -> None:
2123
...
2224

@@ -31,7 +33,7 @@ async def test_cancelation() -> None:
3133
broker = AsyncQueueBroker()
3234
scheduler = TaskiqScheduler(broker, sources=[LabelScheduleSource(broker)])
3335

34-
@broker.task(schedule=[{"time": datetime.utcnow()}])
36+
@broker.task.labels(schedule=[{"time": datetime.now(timezone.utc)}])
3537
def _() -> None:
3638
...
3739

tests/middlewares/test_task_retry.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ async def test_wait_result() -> None:
1414
)
1515
runs = 0
1616

17-
@broker.task(retry_on_error=True)
17+
@broker.task.labels(retry_on_error=True)
1818
async def run_task() -> str:
1919
nonlocal runs
2020

@@ -40,7 +40,7 @@ async def test_wait_result_error() -> None:
4040
runs = 0
4141
lock = asyncio.Lock()
4242

43-
@broker.task(retry_on_error=True)
43+
@broker.task.labels(retry_on_error=True)
4444
async def run_task() -> str:
4545
nonlocal runs, lock
4646

@@ -74,7 +74,7 @@ async def test_wait_result_no_result() -> None:
7474
runs = 0
7575
lock = asyncio.Lock()
7676

77-
@broker.task(retry_on_error=True)
77+
@broker.task.labels(retry_on_error=True)
7878
async def run_task() -> str:
7979
nonlocal runs, done, lock
8080

@@ -111,7 +111,7 @@ async def test_max_retries() -> None:
111111
)
112112
runs = 0
113113

114-
@broker.task(max_retries=10)
114+
@broker.task.labels(max_retries=10)
115115
def run_task() -> str:
116116
nonlocal runs
117117

@@ -137,7 +137,7 @@ async def test_no_retry() -> None:
137137
)
138138
runs = 0
139139

140-
@broker.task(retry_on_error=False, max_retries=10)
140+
@broker.task.labels(retry_on_error=False, max_retries=10)
141141
def run_task() -> str:
142142
nonlocal runs
143143

tests/receiver/test_receiver.py

+1-40
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,16 @@
22
import random
33
import time
44
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Any, ClassVar, List, Optional
5+
from typing import Optional
66

77
import pytest
88
from taskiq_dependencies import Depends
99

1010
from taskiq.abc.broker import AckableMessage, AsyncBroker
11-
from taskiq.abc.middleware import TaskiqMiddleware
1211
from taskiq.brokers.inmemory_broker import InMemoryBroker
1312
from taskiq.exceptions import NoResultError, TaskiqResultTimeoutError
1413
from taskiq.message import TaskiqMessage
1514
from taskiq.receiver import Receiver
16-
from taskiq.result import TaskiqResult
1715
from tests.utils import AsyncQueueBroker
1816

1917

@@ -152,43 +150,6 @@ def test_func() -> None:
152150
assert result.is_err
153151

154152

155-
@pytest.mark.anyio
156-
async def test_run_task_exception_middlewares() -> None:
157-
"""Tests that run_task can run sync tasks."""
158-
159-
class _TestMiddleware(TaskiqMiddleware):
160-
found_exceptions: ClassVar[List[BaseException]] = []
161-
162-
def on_error(
163-
self,
164-
message: "TaskiqMessage",
165-
result: "TaskiqResult[Any]",
166-
exception: BaseException,
167-
) -> None:
168-
self.found_exceptions.append(exception)
169-
170-
def test_func() -> None:
171-
raise ValueError
172-
173-
broker = InMemoryBroker().with_middlewares(_TestMiddleware())
174-
receiver = get_receiver(broker)
175-
176-
result = await receiver.run_task(
177-
test_func,
178-
TaskiqMessage(
179-
task_id="",
180-
task_name="",
181-
labels={},
182-
args=[],
183-
kwargs={},
184-
),
185-
)
186-
assert result.return_value is None
187-
assert result.is_err
188-
assert len(_TestMiddleware.found_exceptions) == 1
189-
assert _TestMiddleware.found_exceptions[0].__class__ is ValueError
190-
191-
192153
@pytest.mark.anyio
193154
async def test_callback_success() -> None:
194155
"""Test that callback function works well."""

0 commit comments

Comments
 (0)