Skip to content

Commit c3dd5a4

Browse files
Abseil Team authored and copybara-github committed
Add extension point for letting TestLoader specify a custom sharding scheme.
PiperOrigin-RevId: 731407072
1 parent 4de3812 commit c3dd5a4

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

absl/testing/absltest.py

+41-13
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import enum
2727
import errno
2828
import faulthandler
29+
import functools
2930
import getpass
3031
import inspect
3132
import io
@@ -2505,6 +2506,35 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
25052506
self._random.shuffle(names)
25062507
return names
25072508

2509+
def shardTestCaseNames(
    self, ordered_names: Sequence[str], shard_index: int, total_shards: int
) -> Sequence[str]:
  """Filters and returns test case names for a specific shard.

  This method is intended to be used in conjunction with test sharding
  (e.g., when running tests on a distributed system or when running tests
  with bazel's test sharding feature). It will return a subset of the
  input test case names, based on the shard index and total shard count.

  Names are assigned to shards round-robin over their *sorted* order: the
  name at sorted position `i` belongs to shard `i % total_shards`.

  Args:
    ordered_names: A sequence of test case names, in the order produced by
      the loader.
    shard_index: The 0-based index of the current shard.
    total_shards: The total number of shards.

  Returns:
    The names from `ordered_names` that belong to the current shard, in the
    same relative order as they appear in `ordered_names`.
  """
  # We need to sort the list of tests in order to determine which tests this
  # shard is responsible for; however, it's important to preserve the order
  # returned by the base loader, e.g. in the case of randomized test ordering.
  # A set keeps the final membership filter linear in the number of names.
  names_in_shard = {
      name
      for position, name in enumerate(sorted(ordered_names))
      if position % total_shards == shard_index
  }
  return [name for name in ordered_names if name in names_in_shard]
2537+
25082538

25092539
def get_default_xml_output_filename() -> Optional[str]:
25102540
if os.environ.get('XML_OUTPUT_FILE'):
@@ -2626,21 +2656,19 @@ def _setup_sharding(
26262656
# the test case names for this shard.
26272657
delegate_get_names = base_loader.getTestCaseNames
26282658

2629-
bucket_iterator = itertools.cycle(range(total_shards))
2659+
def getSharedTestCaseNames(testCaseClass):
2660+
has_shard_test_case_names = hasattr(base_loader, 'shardTestCaseNames')
2661+
if has_shard_test_case_names:
2662+
sharder = getattr(base_loader, 'shardTestCaseNames')
2663+
else:
2664+
sharder = TestLoader.shardTestCaseNames
26302665

2631-
def getShardedTestCaseNames(testCaseClass):
2632-
filtered_names = []
2633-
# We need to sort the list of tests in order to determine which tests this
2634-
# shard is responsible for; however, it's important to preserve the order
2635-
# returned by the base loader, e.g. in the case of randomized test ordering.
2636-
ordered_names = delegate_get_names(testCaseClass)
2637-
for testcase in sorted(ordered_names):
2638-
bucket = next(bucket_iterator)
2639-
if bucket == shard_index:
2640-
filtered_names.append(testcase)
2641-
return [x for x in ordered_names if x in filtered_names]
2666+
names = sharder(
2667+
delegate_get_names(testCaseClass), shard_index, total_shards
2668+
)
2669+
return names
26422670

2643-
base_loader.getTestCaseNames = getShardedTestCaseNames # type: ignore[method-assign]
2671+
base_loader.getTestCaseNames = getSharedTestCaseNames # type: ignore[method-assign]
26442672
return base_loader, shard_index
26452673

26462674

0 commit comments

Comments
 (0)