Skip to content

Commit 1047038

Browse files
committed
feat: improve xdist compatibility
1 parent 4c6d3e7 commit 1047038

File tree

3 files changed

+102
-3
lines changed

3 files changed

+102
-3
lines changed

src/syrupy/data.py

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
)
55
from typing import (
66
TYPE_CHECKING,
7+
Any,
78
Dict,
89
Iterator,
910
List,
@@ -124,6 +125,16 @@ def __iter__(self) -> Iterator["SnapshotCollection"]:
124125
def __contains__(self, key: str) -> bool:
125126
return key in self._snapshot_collections
126127

128+
def serialize(self) -> dict[str, Any]:
129+
return {k: [c.name for c in v] for k, v in self._snapshot_collections.items()}
130+
131+
def merge_serialized(self, data: dict[str, Any]) -> None:
132+
for location, names in data.items():
133+
snapshot_collection = SnapshotCollection(location=location)
134+
for name in names:
135+
snapshot_collection.add(Snapshot(name))
136+
self.update(snapshot_collection)
137+
127138

128139
@dataclass
129140
class DiffedLine:

src/syrupy/report.py

+31
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,37 @@ def _ran_items_match_location(self, snapshot_location: str) -> bool:
508508
for item in self.ran_items
509509
)
510510

511+
def serialize(self) -> dict[str, Any]:
512+
return {
513+
"discovered": self.discovered.serialize(),
514+
"created": self.created.serialize(),
515+
"failed": self.failed.serialize(),
516+
"matched": self.matched.serialize(),
517+
"updated": self.updated.serialize(),
518+
"used": self.used.serialize(),
519+
"_collected_items": [
520+
{
521+
"nodeid": c.nodeid,
522+
"name": c.name,
523+
"path": str(c.path),
524+
"modulename": c.obj.__module__, # type: ignore[attr-defined]
525+
"methodname": c.obj.__name__, # type: ignore[attr-defined]
526+
}
527+
for c in list(self.collected_items)
528+
],
529+
"_selected_items": {
530+
key: status.value for key, status in self.selected_items.items()
531+
},
532+
}
533+
534+
def merge_serialized(self, data: dict[str, Any]) -> None:
535+
self.discovered.merge_serialized(data["discovered"])
536+
self.created.merge_serialized(data["created"])
537+
self.failed.merge_serialized(data["failed"])
538+
self.matched.merge_serialized(data["matched"])
539+
self.updated.merge_serialized(data["updated"])
540+
self.used.merge_serialized(data["used"])
541+
511542

512543
@dataclass(frozen=True)
513544
class Expression:

src/syrupy/session.py

+60-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import os
13
from collections import defaultdict
24
from dataclasses import (
35
dataclass,
@@ -46,6 +48,20 @@ class ItemStatus(Enum):
4648
SKIPPED = "skipped"
4749

4850

51+
class _FakePytestObject:
52+
def __init__(self, collected_item: dict[str, str]) -> None:
53+
self.__module__ = collected_item["modulename"]
54+
self.__name__ = collected_item["methodname"]
55+
56+
57+
class _FakePytestItem:
58+
def __init__(self, collected_item: dict[str, str]) -> None:
59+
self.nodeid = collected_item["nodeid"]
60+
self.name = collected_item["name"]
61+
self.path = Path(collected_item["path"])
62+
self.obj = _FakePytestObject(collected_item)
63+
64+
4965
@dataclass
5066
class SnapshotSession:
5167
pytest_session: "pytest.Session"
@@ -127,6 +143,24 @@ def ran_item(
127143
except ValueError:
128144
pass # if we don't understand the outcome, leave the item as "not run"
129145

146+
def _merge_collected_items(self, collected_items: list[dict[str, str]]) -> None:
147+
for collected_item in collected_items:
148+
custom_item = _FakePytestItem(collected_item)
149+
if not any(
150+
t.nodeid == custom_item.nodeid and t.name == custom_item.nodeid
151+
for t in self._collected_items
152+
):
153+
self._collected_items.add(custom_item) # type: ignore[arg-type]
154+
155+
def _merge_selected_items(self, selected_items: dict[str, str]) -> None:
156+
for key, selected_item in selected_items.items():
157+
if key in self._selected_items:
158+
status = ItemStatus(selected_item)
159+
if status != ItemStatus.NOT_RUN:
160+
self._selected_items[key] = status
161+
else:
162+
self._selected_items[key] = ItemStatus(selected_item)
163+
130164
def finish(self) -> int:
131165
exitstatus = 0
132166
self.flush_snapshot_write_queue()
@@ -139,16 +173,39 @@ def finish(self) -> int:
139173
)
140174

141175
if is_xdist_worker():
142-
# TODO: If we're in a pytest-xdist worker, we need to combine the reports
143-
# of all the workers so that the controller can handle unused
144-
# snapshot removal.
176+
worker_count = os.getenv("PYTEST_XDIST_WORKER_COUNT")
177+
with open(".pytest_syrupy_worker_count", "w", encoding="utf-8") as f:
178+
f.write(worker_count) # type: ignore[arg-type]
179+
with open(
180+
f".pytest_syrupy_{os.getenv("PYTEST_XDIST_WORKER")}_result",
181+
"w",
182+
encoding="utf-8",
183+
) as f:
184+
json.dump(self.report.serialize(), f, indent=2)
145185
return exitstatus
146186
elif is_xdist_controller():
147187
# TODO: If we're in a pytest-xdist controller, merge all the reports.
148188
# Until this is implemented, running syrupy with pytest-xdist is only
149189
# partially functional.
150190
return exitstatus
151191

192+
worker_count = None
193+
try:
194+
with open(".pytest_syrupy_worker_count", encoding="utf-8") as f:
195+
worker_count = f.read()
196+
os.remove(".pytest_syrupy_worker_count")
197+
except FileNotFoundError:
198+
pass
199+
200+
if worker_count:
201+
for i in range(int(worker_count)):
202+
with open(f".pytest_syrupy_gw{i}_result", encoding="utf-8") as f:
203+
data = json.load(f)
204+
self._merge_collected_items(data["_collected_items"])
205+
self._merge_selected_items(data["_selected_items"])
206+
self.report.merge_serialized(data)
207+
os.remove(f".pytest_syrupy_gw{i}_result")
208+
152209
if self.report.num_unused:
153210
if self.update_snapshots:
154211
self.remove_unused_snapshots(

0 commit comments

Comments
 (0)