diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 6bdb4fcf..5b48f698 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -63,6 +63,7 @@ class SnapshotAssertion: include: Optional["PropertyFilter"] = None exclude: Optional["PropertyFilter"] = None matcher: Optional["PropertyMatcher"] = None + extra_args: Dict = field(default_factory=dict) _exclude: Optional["PropertyFilter"] = field( init=False, @@ -105,6 +106,7 @@ def __post_init__(self) -> None: self._include = self.include self._exclude = self.exclude self._matcher = self.matcher + self._extra_args = self.extra_args def __init_extension( self, extension_class: Type["AbstractSyrupyExtension"] @@ -178,6 +180,7 @@ def with_defaults( include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, + extra_args: Optional[Dict] = None, ) -> "SnapshotAssertion": """ Create new snapshot assertion fixture with provided values. This preserves @@ -191,6 +194,7 @@ def with_defaults( test_location=self.test_location, extension_class=extension_class or self.extension_class, session=self.session, + extra_args=extra_args or self.extra_args, ) def use_extension( @@ -205,9 +209,13 @@ def use_extension( def assert_match(self, data: "SerializableData") -> None: assert self == data - def _serialize(self, data: "SerializableData") -> "SerializedData": + def _serialize(self, data: "SerializableData", **kwargs: Any) -> "SerializedData": return self.extension.serialize( - data, exclude=self._exclude, include=self._include, matcher=self.__matcher + data, + exclude=self._exclude, + include=self._include, + matcher=self.__matcher, + **kwargs, ) def get_assert_diff(self) -> List[str]: @@ -264,6 +272,7 @@ def __call__( extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, + extra_args: Optional[Dict] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -280,6 +289,8 @@ def __call__( self.__with_prop("_custom_index", name) if diff is not None: self.__with_prop("_snapshot_diff", diff) + if extra_args: + self.__with_prop("_extra_args", extra_args) return self def __repr__(self) -> str: @@ -300,23 +311,29 @@ def _assert(self, data: "SerializableData") -> bool: matches = False assertion_success = False assertion_exception = None + extra_args = getattr(self, "_extra_args", {}) try: snapshot_data, tainted = self._recall_data(index=self.index) - serialized_data = self._serialize(data) + serialized_data = self._serialize(data, **extra_args) snapshot_diff = getattr(self, "_snapshot_diff", None) if snapshot_diff is not None: - snapshot_data_diff, _ = self._recall_data(index=snapshot_diff) + snapshot_data_diff, _ = self._recall_data( + index=snapshot_diff, **extra_args + ) if snapshot_data_diff is None: raise SnapshotDoesNotExist() serialized_data = self.extension.diff_snapshots( serialized_data=serialized_data, snapshot_data=snapshot_data_diff, + **extra_args, ) matches = ( not tainted and snapshot_data is not None and self.extension.matches( - serialized_data=serialized_data, snapshot_data=snapshot_data + serialized_data=serialized_data, + snapshot_data=snapshot_data, + **extra_args, ) ) assertion_success = matches @@ -361,7 +378,7 @@ def _post_assert(self) -> None: self._post_assert_actions.pop()() def _recall_data( - self, index: "SnapshotIndex" + self, index: "SnapshotIndex", **kwargs: Any ) -> Tuple[Optional["SerializableData"], bool]: try: return ( @@ -369,6 +386,7 @@ def _recall_data( test_location=self.test_location, index=index, session_id=str(id(self.session)), + **kwargs, ), False, ) diff --git a/src/syrupy/extensions/amber/__init__.py b/src/syrupy/extensions/amber/__init__.py index 74dbc33b..98e722bd 100644 --- a/src/syrupy/extensions/amber/__init__.py +++ b/src/syrupy/extensions/amber/__init__.py @@ -47,7 +47,9 @@ def delete_snapshots( else: Path(snapshot_location).unlink() - def _read_snapshot_collection(self, snapshot_location: str) -> "SnapshotCollection": + def _read_snapshot_collection( + self, snapshot_location: str, **kwargs: Any + ) -> "SnapshotCollection": return self.serializer_class.read_file(snapshot_location) @classmethod @@ -72,7 +74,7 @@ def _read_snapshot_data_from_location( @classmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any ) -> None: cls.serializer_class.write_file(snapshot_collection, merge=True) diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 945cf20b..3cb0c5e7 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Callable, Dict, Iterator, @@ -67,6 +68,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": """ Serializes a python object / data structure into a string @@ -108,7 +110,7 @@ def is_snapshot_location(self, *, location: str) -> bool: return location.endswith(self._file_extension) def discover_snapshots( - self, *, test_location: "PyTestLocation" + self, *, test_location: "PyTestLocation", **kwargs: Any ) -> "SnapshotCollections": """ Returns all snapshot collections in test site @@ -117,7 +119,7 @@ def discover_snapshots( for filepath in walk_snapshot_dir(self.dirname(test_location=test_location)): if self.is_snapshot_location(location=filepath): snapshot_collection = self._read_snapshot_collection( - snapshot_location=filepath + snapshot_location=filepath, **kwargs ) if not snapshot_collection.has_snapshots: snapshot_collection = SnapshotEmptyCollection(location=filepath) @@ -134,6 +136,7 @@ def read_snapshot( test_location: "PyTestLocation", index: "SnapshotIndex", session_id: str, + **kwargs: Any, ) -> "SerializedData": """ This method is _final_, do not override. You can override @@ -145,6 +148,7 @@ def read_snapshot( snapshot_location=snapshot_location, snapshot_name=snapshot_name, session_id=session_id, + **kwargs, ) if snapshot_data is None: raise SnapshotDoesNotExist() @@ -216,7 +220,7 @@ def delete_snapshots( @abstractmethod def _read_snapshot_collection( - self, *, snapshot_location: str + self, *, snapshot_location: str, **kwargs: Any ) -> "SnapshotCollection": """ Read the snapshot location and construct a snapshot collection object @@ -225,7 +229,12 @@ def _read_snapshot_collection( @abstractmethod def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str + self, + *, + snapshot_location: str, + snapshot_name: str, + session_id: str, + **kwargs: Any, ) -> Optional["SerializedData"]: """ Get only the snapshot data from location for assertion @@ -235,7 +244,7 @@ def _read_snapshot_data_from_location( @classmethod @abstractmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, *, snapshot_collection: "SnapshotCollection", **kwargs: Any ) -> None: """ Adds the snapshot data to the snapshots in collection location @@ -243,7 +252,7 @@ def _write_snapshot_collection( raise NotImplementedError @classmethod - def dirname(cls, *, test_location: "PyTestLocation") -> str: + def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str: test_dir = Path(test_location.filepath).parent return str(test_dir.joinpath(SNAPSHOT_DIRNAME)) @@ -259,7 +268,10 @@ class SnapshotReporter(ABC): _context_line_count = 1 def diff_snapshots( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", + **kwargs: Any, ) -> "SerializedData": env = {DISABLE_COLOR_ENV_VAR: "true"} attrs = {"_context_line_count": 0} @@ -267,7 +279,10 @@ def diff_snapshots( return "\n".join(self.diff_lines(serialized_data, snapshot_data)) def diff_lines( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", + **kwargs: Any, ) -> Iterator[str]: for line in self.__diff_lines(str(snapshot_data), str(serialized_data)): yield reset(line) @@ -407,6 +422,7 @@ def matches( *, serialized_data: "SerializableData", snapshot_data: "SerializableData", + **kwargs: Any, ) -> bool: """ Compares serialized data and snapshot data and returns diff --git a/src/syrupy/extensions/json/__init__.py b/src/syrupy/extensions/json/__init__.py index 5b52a8d5..ccc9488e 100644 --- a/src/syrupy/extensions/json/__init__.py +++ b/src/syrupy/extensions/json/__init__.py @@ -145,6 +145,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": data = self._filter( data=data, diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 0b216115..6f421360 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Optional, Set, Type, @@ -49,6 +50,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": return self.get_supported_dataclass()(data) @@ -74,12 +76,15 @@ def _get_file_basename( return cls.get_snapshot_name(test_location=test_location, index=index) @classmethod - def dirname(cls, *, test_location: "PyTestLocation") -> str: + def dirname(cls, *, test_location: "PyTestLocation", **kwargs: Any) -> str: original_dirname = AbstractSyrupyExtension.dirname(test_location=test_location) return str(Path(original_dirname).joinpath(test_location.basename)) def _read_snapshot_collection( - self, *, snapshot_location: str + self, + *, + snapshot_location: str, + **kwargs: Any, ) -> "SnapshotCollection": file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0 filename_wo_ext = snapshot_location[:-file_ext_len] @@ -116,7 +121,10 @@ def get_write_encoding(cls) -> Optional[str]: @classmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, + *, + snapshot_collection: "SnapshotCollection", + **kwargs: Any, ) -> None: filepath, data = ( snapshot_collection.location,