Skip to content

Commit e49a5f7

Browse files
committed
Add support for inspect.iscoroutinefunction() in Coroutine provider
1 parent 2330122 commit e49a5f7

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

src/dependency_injector/providers.pyx

+10-10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import absolute_import
44

5+
import asyncio
56
import copy
67
import errno
78
import functools
@@ -27,16 +28,14 @@ except ImportError:
2728
import __builtin__ as builtins
2829

2930
try:
30-
import asyncio
31+
from inspect import _is_coroutine_marker
3132
except ImportError:
32-
asyncio = None
33-
_is_coroutine_marker = None
34-
else:
35-
if sys.version_info >= (3, 5, 3):
36-
import asyncio.coroutines
37-
_is_coroutine_marker = asyncio.coroutines._is_coroutine
38-
else:
39-
_is_coroutine_marker = True
33+
_is_coroutine_marker = True
34+
35+
try:
36+
from asyncio.coroutines import _is_coroutine
37+
except ImportError:
38+
_is_coroutine = True
4039

4140
try:
4241
import ConfigParser as iniconfigparser
@@ -1475,7 +1474,8 @@ cdef class Coroutine(Callable):
14751474
some_coroutine.add_kwargs(keyword_argument1=3, keyword_argument=4)
14761475
"""
14771476

1478-
_is_coroutine = _is_coroutine_marker
1477+
_is_coroutine_marker = _is_coroutine_marker # Python >=3.12
1478+
_is_coroutine = _is_coroutine # Python <3.16
14791479

14801480
def set_provides(self, provides):
14811481
"""Set provider provides."""

tests/unit/providers/coroutines/test_coroutine_py35.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Coroutine provider tests."""
2+
import sys
23

34
from dependency_injector import providers, errors
45
from pytest import mark, raises
@@ -208,3 +209,17 @@ def test_repr():
208209
"<dependency_injector.providers."
209210
"Coroutine({0}) at {1}>".format(repr(example), hex(id(provider)))
210211
)
212+
213+
214+
@mark.skipif(sys.version_info > (3, 15), reason="requires Python<3.16")
215+
def test_asyncio_iscoroutinefunction() -> None:
216+
from asyncio.coroutines import iscoroutinefunction
217+
218+
assert iscoroutinefunction(providers.Coroutine(example))
219+
220+
221+
@mark.skipif(sys.version_info < (3, 12), reason="requires Python>=3.12")
222+
def test_inspect_iscoroutinefunction() -> None:
223+
from inspect import iscoroutinefunction
224+
225+
assert iscoroutinefunction(providers.Coroutine(example))

0 commit comments

Comments
 (0)