1
+ import json
2
+ import os
1
3
from collections import defaultdict
2
4
from dataclasses import (
3
5
dataclass ,
@@ -46,6 +48,20 @@ class ItemStatus(Enum):
46
48
SKIPPED = "skipped"
47
49
48
50
51
+ class _FakePytestObject :
52
+ def __init__ (self , collected_item : dict [str , str ]) -> None :
53
+ self .__module__ = collected_item ["modulename" ]
54
+ self .__name__ = collected_item ["methodname" ]
55
+
56
+
57
+ class _FakePytestItem :
58
+ def __init__ (self , collected_item : dict [str , str ]) -> None :
59
+ self .nodeid = collected_item ["nodeid" ]
60
+ self .name = collected_item ["name" ]
61
+ self .path = Path (collected_item ["path" ])
62
+ self .obj = _FakePytestObject (collected_item )
63
+
64
+
49
65
@dataclass
50
66
class SnapshotSession :
51
67
pytest_session : "pytest.Session"
@@ -127,6 +143,24 @@ def ran_item(
127
143
except ValueError :
128
144
pass # if we don't understand the outcome, leave the item as "not run"
129
145
146
+ def _merge_collected_items (self , collected_items : list [dict [str , str ]]) -> None :
147
+ for collected_item in collected_items :
148
+ custom_item = _FakePytestItem (collected_item )
149
+ if not any (
150
+ t .nodeid == custom_item .nodeid and t .name == custom_item .nodeid
151
+ for t in self ._collected_items
152
+ ):
153
+ self ._collected_items .add (custom_item ) # type: ignore[arg-type]
154
+
155
+ def _merge_selected_items (self , selected_items : dict [str , str ]) -> None :
156
+ for key , selected_item in selected_items .items ():
157
+ if key in self ._selected_items :
158
+ status = ItemStatus (selected_item )
159
+ if status != ItemStatus .NOT_RUN :
160
+ self ._selected_items [key ] = status
161
+ else :
162
+ self ._selected_items [key ] = ItemStatus (selected_item )
163
+
130
164
def finish (self ) -> int :
131
165
exitstatus = 0
132
166
self .flush_snapshot_write_queue ()
@@ -139,16 +173,39 @@ def finish(self) -> int:
139
173
)
140
174
141
175
if is_xdist_worker ():
142
- # TODO: If we're in a pytest-xdist worker, we need to combine the reports
143
- # of all the workers so that the controller can handle unused
144
- # snapshot removal.
176
+ worker_count = os .getenv ("PYTEST_XDIST_WORKER_COUNT" )
177
+ with open (".pytest_syrupy_worker_count" , "w" , encoding = "utf-8" ) as f :
178
+ f .write (worker_count ) # type: ignore[arg-type]
179
+ with open (
180
+ f".pytest_syrupy_{ os .getenv ("PYTEST_XDIST_WORKER" )} _result" ,
181
+ "w" ,
182
+ encoding = "utf-8" ,
183
+ ) as f :
184
+ json .dump (self .report .serialize (), f , indent = 2 )
145
185
return exitstatus
146
186
elif is_xdist_controller ():
147
187
# TODO: If we're in a pytest-xdist controller, merge all the reports.
148
188
# Until this is implemented, running syrupy with pytest-xdist is only
149
189
# partially functional.
150
190
return exitstatus
151
191
192
+ worker_count = None
193
+ try :
194
+ with open (".pytest_syrupy_worker_count" , encoding = "utf-8" ) as f :
195
+ worker_count = f .read ()
196
+ os .remove (".pytest_syrupy_worker_count" )
197
+ except FileNotFoundError :
198
+ pass
199
+
200
+ if worker_count :
201
+ for i in range (int (worker_count )):
202
+ with open (f".pytest_syrupy_gw{ i } _result" , encoding = "utf-8" ) as f :
203
+ data = json .load (f )
204
+ self ._merge_collected_items (data ["_collected_items" ])
205
+ self ._merge_selected_items (data ["_selected_items" ])
206
+ self .report .merge_serialized (data )
207
+ os .remove (f".pytest_syrupy_gw{ i } _result" )
208
+
152
209
if self .report .num_unused :
153
210
if self .update_snapshots :
154
211
self .remove_unused_snapshots (
0 commit comments