Skip to content

avoid holding a reference to exception and value in to_thread_run_sync #3229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3229.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid holding refs to result/exception from ``trio.to_thread.run_sync``.
16 changes: 1 addition & 15 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
create_asyncio_future_in_new_loop,
gc_collect_harder,
ignore_coroutine_never_awaited_warnings,
no_other_refs,
restore_unraisablehook,
slow,
)
Expand Down Expand Up @@ -2802,25 +2803,10 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non
assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__)


if sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]


@pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="Only makes sense with refcounting GC",
)
@pytest.mark.xfail(
sys.version_info >= (3, 14),
reason="https://github.com/python/cpython/issues/125603",
)
async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None:
class MyException(Exception):
pass
Expand Down
17 changes: 17 additions & 0 deletions src/trio/_core/_tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) ->
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()


if sys.version_info >= (3, 14):

def no_other_refs() -> list[object]:
gen = sys._getframe().f_generator
return [] if gen is None else [gen]

elif sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
58 changes: 57 additions & 1 deletion src/trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextvars
import gc
import queue as stdlib_queue
import re
import sys
Expand Down Expand Up @@ -29,7 +30,7 @@
sleep_forever,
)
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import slow
from .._core._tests.tutil import gc_collect_harder, no_other_refs, slow
from .._threads import (
active_thread_count,
current_default_thread_limiter,
Expand Down Expand Up @@ -1141,3 +1142,58 @@ async def wait_no_threads_left() -> None:
async def test_wait_all_threads_completed_no_threads() -> None:
await wait_all_threads_completed()
assert active_thread_count() == 0


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_worker_references() -> None:
class Foo:
pass

def foo(_: Foo) -> Foo:
return Foo()

cvar = contextvars.ContextVar[Foo]("cvar")
contextval = Foo()
arg = Foo()
cvar.set(contextval)
v = await to_thread_run_sync(foo, arg)

cvar.set(Foo())
gc_collect_harder()

assert gc.get_referrers(contextval) == no_other_refs()
assert gc.get_referrers(foo) == no_other_refs()
assert gc.get_referrers(arg) == no_other_refs()
assert gc.get_referrers(v) == no_other_refs()


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_workerreferences_exc() -> None:

class MyException(Exception):
pass

def throw() -> None:
raise MyException

e = None
try:
await to_thread_run_sync(throw)
except MyException as err:
e = err

gc_collect_harder()

assert gc.get_referrers(e) == no_other_refs()
64 changes: 50 additions & 14 deletions src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import queue as stdlib_queue
import threading
from itertools import count
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Final, Generic, NoReturn, Protocol, TypeVar

import attrs
import outcome
Expand Down Expand Up @@ -36,6 +36,7 @@
Ts = TypeVarTuple("Ts")

RetT = TypeVar("RetT")
T_co = TypeVar("T_co", covariant=True)


class _ParentTaskData(threading.local):
Expand Down Expand Up @@ -253,6 +254,32 @@ def run_in_system_nursery(self, token: TrioToken) -> None:
token.run_sync_soon(self.run_sync)


class _SupportsUnwrap(Protocol, Generic[T_co]):
def unwrap(self) -> T_co: ...


class _Value(_SupportsUnwrap[T_co]):
def __init__(self, v: T_co) -> None:
self._v: Final = v

def unwrap(self) -> T_co:
try:
return self._v
finally:
del self._v


class _Error(_SupportsUnwrap[NoReturn]):
def __init__(self, e: BaseException) -> None:
self._e: Final = e

def unwrap(self) -> NoReturn:
try:
raise self._e
finally:
del self._e


@enable_ki_protection
async def to_thread_run_sync(
sync_fn: Callable[[Unpack[Ts]], RetT],
Expand Down Expand Up @@ -375,8 +402,15 @@ def do_release_then_return_result() -> RetT:
limiter.release_on_behalf_of(placeholder)

result = outcome.capture(do_release_then_return_result)
if isinstance(result, outcome.Error):
result2: _SupportsUnwrap[RetT] = _Error(result.error)
elif isinstance(result, outcome.Value):
result2 = _Value(result.value)
else:
raise RuntimeError("invalid outcome")
del result
if task_register[0] is not None:
trio.lowlevel.reschedule(task_register[0], outcome.Value(result))
trio.lowlevel.reschedule(task_register[0], outcome.Value(result2))

current_trio_token = trio.lowlevel.current_trio_token()

Expand Down Expand Up @@ -440,20 +474,22 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
msg_from_thread: _Value[RetT] | _Error | Run[object] | RunSync[object] = (
await trio.lowlevel.wait_task_rescheduled(abort)
)
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
del msg_from_thread
try:
if isinstance(msg_from_thread, (_Value, _Error)):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
finally:
del msg_from_thread


def from_thread_check_cancelled() -> None:
Expand Down
Loading