|
26 | 26 | import enum
|
27 | 27 | import errno
|
28 | 28 | import faulthandler
|
| 29 | +import functools |
29 | 30 | import getpass
|
30 | 31 | import inspect
|
31 | 32 | import io
|
@@ -2505,6 +2506,35 @@ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
|
2505 | 2506 | self._random.shuffle(names)
|
2506 | 2507 | return names
|
2507 | 2508 |
|
| 2509 | + def shardTestCaseNames( |
| 2510 | + self, ordered_names: Sequence[str], shard_index: int, total_shards: int |
| 2511 | + ) -> Sequence[str]: |
| 2512 | + """Filters and returns test case names for a specific shard. |
| 2513 | +
|
| 2514 | + This method is intended to be used in conjunction with test sharding |
| 2515 | + (e.g., when running tests on a distributed system or when running tests |
| 2516 | + with bazel's test sharding feature). It will return a subset of the |
| 2517 | + input test case names, based on the shard index and total shard count. |
| 2518 | +
|
| 2519 | + Args: |
| 2520 | + names: A sequence of test case names. |
| 2521 | + shard_index: The index of the current shard. |
| 2522 | + total_shards: The total number of shards. |
| 2523 | +
|
| 2524 | + Returns: |
| 2525 | + A sequence of test case names for the current shard. |
| 2526 | + """ |
| 2527 | + bucket_iterator = itertools.cycle(range(total_shards)) |
| 2528 | + filtered_names = [] |
| 2529 | + # We need to sort the list of tests in order to determine which tests this |
| 2530 | + # shard is responsible for; however, it's important to preserve the order |
| 2531 | + # returned by the base loader, e.g. in the case of randomized test ordering. |
| 2532 | + for testcase in sorted(ordered_names): |
| 2533 | + bucket = next(bucket_iterator) |
| 2534 | + if bucket == shard_index: |
| 2535 | + filtered_names.append(testcase) |
| 2536 | + return [x for x in ordered_names if x in filtered_names] |
| 2537 | + |
2508 | 2538 |
|
2509 | 2539 | def get_default_xml_output_filename() -> Optional[str]:
|
2510 | 2540 | if os.environ.get('XML_OUTPUT_FILE'):
|
@@ -2626,21 +2656,19 @@ def _setup_sharding(
|
2626 | 2656 | # the test case names for this shard.
|
2627 | 2657 | delegate_get_names = base_loader.getTestCaseNames
|
2628 | 2658 |
|
2629 |
| - bucket_iterator = itertools.cycle(range(total_shards)) |
| 2659 | + def getSharedTestCaseNames(testCaseClass): |
| 2660 | + has_shard_test_case_names = hasattr(base_loader, 'shardTestCaseNames') |
| 2661 | + if has_shard_test_case_names: |
| 2662 | + sharder = getattr(base_loader, 'shardTestCaseNames') |
| 2663 | + else: |
| 2664 | + sharder = TestLoader.shardTestCaseNames |
2630 | 2665 |
|
2631 |
| - def getShardedTestCaseNames(testCaseClass): |
2632 |
| - filtered_names = [] |
2633 |
| - # We need to sort the list of tests in order to determine which tests this |
2634 |
| - # shard is responsible for; however, it's important to preserve the order |
2635 |
| - # returned by the base loader, e.g. in the case of randomized test ordering. |
2636 |
| - ordered_names = delegate_get_names(testCaseClass) |
2637 |
| - for testcase in sorted(ordered_names): |
2638 |
| - bucket = next(bucket_iterator) |
2639 |
| - if bucket == shard_index: |
2640 |
| - filtered_names.append(testcase) |
2641 |
| - return [x for x in ordered_names if x in filtered_names] |
| 2666 | + names = sharder( |
| 2667 | + delegate_get_names(testCaseClass), shard_index, total_shards |
| 2668 | + ) |
| 2669 | + return names |
2642 | 2670 |
|
2643 |
| - base_loader.getTestCaseNames = getShardedTestCaseNames # type: ignore[method-assign] |
| 2671 | + base_loader.getTestCaseNames = getSharedTestCaseNames # type: ignore[method-assign] |
2644 | 2672 | return base_loader, shard_index
|
2645 | 2673 |
|
2646 | 2674 |
|
|
0 commit comments