0

Shard tests by SDK version.

Groups tests that run on the same SDK version into the same shard.

Bug: 1383650
Change-Id: I07f94d61091cb064ac08559019cb8d90de978f96
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4080813
Commit-Queue: Benjamin Joyce (Ben) <bjoyce@chromium.org>
Reviewed-by: Andrew Grieve <agrieve@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1164318}
This commit is contained in:
Ben Joyce
2023-06-29 20:26:40 +00:00
committed by Chromium LUCI CQ
parent b77c227878
commit 234901d37f
2 changed files with 209 additions and 99 deletions

@ -2,6 +2,7 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import itertools
import json
import logging
import multiprocessing
@ -14,8 +15,9 @@ import tempfile
import threading
import time
import zipfile
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from six.moves import range # pylint: disable=redefined-builtin
from devil.utils import cmd_helper
from py_utils import tempfile_ext
@ -25,12 +27,9 @@ from pylib.base import test_run
from pylib.constants import host_paths
from pylib.results import json_results
# Suites we shouldn't shard, usually because they don't contain enough test
# cases.
_EXCLUDED_SUITES = {
'password_check_junit_tests',
'touch_to_fill_junit_tests',
}
# Target size of a job: GroupTestsForShard closes a job once it holds at least
# this many tests. Whole test classes are added at a time, so a job may end up
# larger than this value.
# Chosen after timing test runs of chrome_junit_tests with 7,16,32,
# and 64 workers in threadpool and different classes_per_job.
_MAX_TESTS_PER_JOB = 150
_FAILURE_TYPES = (
base_test_result.ResultType.FAIL,
@ -38,8 +37,9 @@ _FAILURE_TYPES = (
base_test_result.ResultType.TIMEOUT,
)
# Running the largest test suite with a single shard takes about 22 minutes.
_SHARD_TIMEOUT = 30 * 60
# Test suites are broken up into batches or "jobs" of about 150 tests.
# Each job should take no longer than 30 seconds.
_JOB_TIMEOUT = 30
# RegExp to detect logcat lines, e.g., 'I/AssetManager: not found'.
_LOGCAT_RE = re.compile(r'(:?\d+\| )?[A-Z]/[\w\d_-]+:')
@ -155,12 +155,12 @@ class LocalMachineJunitTestRun(test_run.TestRun):
# override
def RunTests(self, results, raw_logs_fh=None):
# Takes .5-3 seconds to list tests, depending on the number of tests.
test_classes = _GetTestClasses(self.GetTestsForListing())
shards = ChooseNumOfShards(test_classes, self._test_instance.shards)
# TODO(1384204): This step can take up to 3.5 seconds to execute when there
# are a lot of tests.
test_list = self.GetTestsForListing()
grouped_tests = GroupTestsForShard(test_list)
grouped_tests = GroupTestsForShard(shards, test_classes)
shard_list = list(range(shards))
shard_list = list(range(len(grouped_tests)))
shard_filter = self._test_instance.shard_filter
if shard_filter:
shard_list = [x for x in shard_list if x in shard_filter]
@ -175,11 +175,14 @@ class LocalMachineJunitTestRun(test_run.TestRun):
results.append(test_run_results)
return
num_workers = ChooseNumOfWorkers(len(test_list), self._test_instance.shards)
if shard_filter:
logging.warning('Running test shards: %s',
', '.join(str(x) for x in shard_list))
logging.warning('Running test shards: %s using %s concurrent process(es)',
', '.join(str(x) for x in shard_list), num_workers)
else:
logging.warning('Running tests on %d shard(s).', shards)
logging.warning(
'Running tests with %d shard(s) using %s concurrent process(es).',
len(grouped_tests), num_workers)
with tempfile_ext.NamedTemporaryDirectory() as temp_dir:
cmd_list = [[self._wrapper_path] for _ in shard_list]
@ -190,7 +193,7 @@ class LocalMachineJunitTestRun(test_run.TestRun):
g for i, g in enumerate(grouped_tests) if i in shard_list
]
jar_args_list = self._CreateJarArgsList(json_result_file_paths,
active_groups, shards)
active_groups, num_workers)
if jar_args_list:
for cmd, jar_args in zip(cmd_list, jar_args_list):
cmd += ['--jar-args', '"%s"' % ' '.join(jar_args)]
@ -204,7 +207,8 @@ class LocalMachineJunitTestRun(test_run.TestRun):
show_logcat = logging.getLogger().isEnabledFor(logging.INFO)
num_omitted_lines = 0
for line in _RunCommandsAndSerializeOutput(cmd_list, shard_list):
for line in _RunCommandsAndSerializeOutput(cmd_list, num_workers,
shard_list):
if raw_logs_fh:
raw_logs_fh.write(line)
if show_logcat or not _LOGCAT_RE.match(line):
@ -236,7 +240,7 @@ class LocalMachineJunitTestRun(test_run.TestRun):
base_test_result.ResultType.UNKNOWN)
]
if shards > 1 and failed_shards:
if num_workers > 1 and failed_shards:
for i in failed_shards:
filt = ':'.join(grouped_tests[i])
print(f'Test filter for failed shard {i}: --test-filter "{filt}"')
@ -244,8 +248,7 @@ class LocalMachineJunitTestRun(test_run.TestRun):
print(
f'{len(failed_shards)} shards had failing tests. To re-run only '
f'these shards, use the above filter flags, or use: '
f'--shards {shards} --shard-filter',
','.join(str(x) for x in failed_shards))
f'--shard-filter', ','.join(str(x) for x in failed_shards))
test_run_results = base_test_result.TestRunResults()
test_run_results.AddResults(results_list)
@ -275,43 +278,58 @@ def AddPropertiesJar(cmd_list, temp_dir, resource_apk):
cmd.extend(['--classpath', properties_jar_path])
def ChooseNumOfWorkers(num_jobs, num_workers=None):
  """Chooses how many concurrent processes to run test jobs with.

  Args:
    num_jobs: Number of jobs that will be run. More workers than jobs would
      only sit idle, so this caps the result.
    num_workers: Requested number of workers (e.g. from --shards). When None,
      defaults to half of the host's CPU count.

  Returns:
    The number of workers to use, between 1 and max(1, num_jobs). The value is
    never 0, so it is always valid as ThreadPoolExecutor(max_workers=...).
  """
  if num_workers is None:
    # Local tests of explicit --shards values show that max speed is achieved
    # at cpu_count() / 2.
    num_workers = max(1, multiprocessing.cpu_count() // 2)
  # With zero jobs (e.g. an empty test list) or an explicit request of 0,
  # min() alone would return 0, which ThreadPoolExecutor rejects with
  # ValueError.
  return max(1, min(num_workers, num_jobs))
def GroupTestsForShard(test_list):
  """Groups tests into jobs so that each job runs a single SDK version.

  Tests are first bucketed by SDK version. Within a bucket, tests are sorted
  by class and whole classes are appended to the current job. Once the job
  holds at least _MAX_TESTS_PER_JOB tests it is closed and a new one started,
  so a class is never split across jobs.

  Args:
    test_list: A list of test names of the form 'pkg.Class#method[sdk]',
      e.g. 'a.b#c[28]'.

  Returns:
    A list of test lists, one per job, with '#' in each name replaced by '.'.
    No job is empty, except that [[]] is returned when there are no tests so
    that the test runner reports an error.
  """
  tests_by_sdk = defaultdict(set)
  unparsed_tests = []
  for test in test_list:
    match = _TEST_SDK_VERSION.match(test)
    if match:
      class_name, sdk_ver = match.groups()
    else:
      # Don't crash on None.groups() for a name the pattern can't parse; keep
      # the test runnable, grouped by the text before '#' with no SDK version.
      unparsed_tests.append(test)
      class_name, sdk_ver = test.partition('#')[0], None
    tests_by_sdk[sdk_ver].add((class_name, test))
  if unparsed_tests:
    logging.warning('Could not parse the class from test(s): %s',
                    unparsed_tests)

  ret = []
  for tests_for_sdk in tests_by_sdk.values():
    # TODO(1458958): Group by classes instead of test names and
    # add --sdk-version as filter option. This will reduce filter verbiage.
    curr_tests = []
    # groupby() only merges adjacent items, so sort by class name first.
    for _, class_tests in itertools.groupby(sorted(tests_for_sdk),
                                            key=lambda pair: pair[0]):
      curr_tests.extend(name.replace('#', '.') for _, name in class_tests)
      if len(curr_tests) >= _MAX_TESTS_PER_JOB:
        ret.append(curr_tests)
        curr_tests = []
    # The last class may have filled the job exactly to the limit; emitting
    # the then-empty remainder would create a job with no tests, which makes
    # the test runner error out.
    if curr_tests:
      ret.append(curr_tests)
  # Add an empty shard so that the test runner can throw an error from not
  # having any tests.
  if not ret:
    ret.append([])
  return ret
@ -328,11 +346,13 @@ def _DumpJavaStacks(pid):
return result.stdout
def _RunCommandsAndSerializeOutput(cmd_list, shard_list):
def _RunCommandsAndSerializeOutput(cmd_list, num_workers, shard_list):
"""Runs multiple commands in parallel and yields serialized output lines.
Args:
cmd_list: List of command lists to run.
num_workers: The number of concurrent processes to run jobs in the
shard_list.
shard_list: Shard index of each command list.
Raises:
@ -347,8 +367,6 @@ def _RunCommandsAndSerializeOutput(cmd_list, shard_list):
for _ in range(num_shards - 1):
temp_files.append(tempfile.TemporaryFile(mode='w+t', encoding='utf-8'))
deadline = time.time() + (_SHARD_TIMEOUT / (num_shards // 2 + 1))
yield '\n'
yield f'Shard {shard_list[0]} output:\n'
@ -369,7 +387,7 @@ def _RunCommandsAndSerializeOutput(cmd_list, shard_list):
return proc
try:
proc.wait(timeout=deadline - time.time())
proc.wait(timeout=(time.time() + _JOB_TIMEOUT))
except subprocess.TimeoutExpired:
timeout_dumps[idx] = _DumpJavaStacks(proc.pid)
proc.kill()
@ -377,13 +395,13 @@ def _RunCommandsAndSerializeOutput(cmd_list, shard_list):
# Not needed, but keeps pylint happy.
return None
with ThreadPoolExecutor(max_workers=num_shards) as pool:
with ThreadPoolExecutor(max_workers=num_workers) as pool:
futures = []
for i, cmd in enumerate(cmd_list):
futures.append(pool.submit(run_proc, cmd=cmd, idx=i))
yield from _StreamFirstShardOutput(shard_list[0], futures[0].result(),
deadline)
time.time() + _JOB_TIMEOUT)
for i, shard in enumerate(shard_list[1:]):
# Shouldn't cause timeout as run_proc terminates the process with
@ -441,21 +459,3 @@ def _StreamFirstShardOutput(shard, shard_proc, deadline):
line = shard_queue.get()
if line:
yield f'{shard:2}| {line}'
def _GetTestClasses(test_list):
  """Returns the distinct test class names found in |test_list|.

  The class name is the first group captured by _TEST_SDK_VERSION. Names the
  pattern does not match are skipped and reported together in one warning.

  Args:
    test_list: A list of test names.

  Returns:
    A list of unique class names, in set iteration order.
  """
  class_names = set()
  unparsed = []
  for test_name in test_list:
    parsed = _TEST_SDK_VERSION.match(test_name)
    if parsed is None:
      unparsed.append(test_name)
      continue
    class_names.add(parsed.group(1))

  logging.info('Found %d test classes.', len(class_names))
  if unparsed:
    logging.warning('Could not parse the class from test(s): %s', unparsed)
  return list(class_names)

@ -15,6 +15,9 @@ from mock import patch # pylint: disable=import-error
class LocalMachineJunitTestRunTests(unittest.TestCase):
def setUp(self):
  # Pin the job size the grouping tests were written against. Also register a
  # cleanup that restores the module's real value after each test, so this
  # override (and ones made inside individual tests) cannot leak into other
  # test cases.
  self.addCleanup(setattr, local_machine_junit_test_run, '_MAX_TESTS_PER_JOB',
                  local_machine_junit_test_run._MAX_TESTS_PER_JOB)
  local_machine_junit_test_run._MAX_TESTS_PER_JOB = 150
def testAddPropertiesJar(self):
with tempfile_ext.NamedTemporaryDirectory() as temp_dir:
apk = 'resource_apk'
@ -38,40 +41,147 @@ class LocalMachineJunitTestRunTests(unittest.TestCase):
@patch('multiprocessing.cpu_count')
def testChooseNumOfShards(self, mock_cpu_count):
mock_cpu_count.return_value = 37
# Tests set by num_cpus
test_classes = [1] * 500
shards = local_machine_junit_test_run.ChooseNumOfShards(test_classes)
jobs = 500
shards = local_machine_junit_test_run.ChooseNumOfWorkers(jobs)
self.assertEqual(18, shards)
# Tests using min_class per shards.
test_classes = [1] * 20
shards = local_machine_junit_test_run.ChooseNumOfShards(test_classes)
# Number of jobs is less than cpu count.
jobs = 5
shards = local_machine_junit_test_run.ChooseNumOfWorkers(jobs)
self.assertEqual(5, shards)
# Number of test groups is less than shard request.
shards = local_machine_junit_test_run.ChooseNumOfWorkers(jobs, 18)
self.assertEqual(5, shards)
# Shard request is less than job group.
shards = local_machine_junit_test_run.ChooseNumOfWorkers(jobs, 2)
self.assertEqual(2, shards)
# Tests set by flag.
shards = local_machine_junit_test_run.ChooseNumOfShards(test_classes, 18)
self.assertEqual(18, shards)
shards = local_machine_junit_test_run.ChooseNumOfShards(test_classes, 2)
self.assertEqual(2, shards)
def testGroupTestsForShardWithSdk(self):
test_list = []
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
self.assertListEqual(results, [[]])
def testGroupTestsForShard(self):
test_classes = []
results = local_machine_junit_test_run.GroupTestsForShard(1, test_classes)
self.assertEqual(results, [[]])
# All the same SDK and classes. Should come back as one job.
t1 = 'a.b#c[28]'
t2 = 'a.b#d[28]'
t3 = 'a.b#e[28]'
test_list = [t1, t2, t3]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
ans = [[t1.replace('#', '.'), t2.replace('#', '.'), t3.replace('#', '.')]]
for idx, _ in enumerate(ans):
self.assertCountEqual(ans[idx], results[idx])
test_classes = ['dir/test'] * 5
results = local_machine_junit_test_run.GroupTestsForShard(1, test_classes)
self.assertEqual(results, [['dir.test*'] * 5])
# Tests same class, but different sdks.
# Should come back as 3 jobs as they're different sdks.
t1 = 'a.b#c[28]'
t2 = 'a.b#d[27]'
t3 = 'a.b#e[26]'
test_list = [t1, t2, t3]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
ans = [[t1.replace('#', '.')], [t2.replace('#', '.')],
[t3.replace('#', '.')]]
self.assertCountEqual(results, ans)
test_classes = ['dir/test'] * 5
results = local_machine_junit_test_run.GroupTestsForShard(2, test_classes)
ans_dict = [['dir.test*'] * 3, ['dir.test*'] * 2]
self.assertEqual(results, ans_dict)
# Tests having different tests and sdks.
# Should come back as 3 jobs.
t1 = 'a.1#c[28]'
t2 = 'a.2#d[27]'
t3 = 'a.3#e[26]'
t4 = 'a.4#e[26]'
test_list = [t1, t2, t3, t4]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
test_classes = ['a10 warthog', 'b17', 'SR71']
results = local_machine_junit_test_run.GroupTestsForShard(3, test_classes)
ans_dict = [['a10 warthog*'], ['b17*'], ['SR71*']]
self.assertEqual(results, ans_dict)
ans = [[t2.replace('#', '.')], [t3.replace('#', '.'),
t4.replace('#', '.')],
[t1.replace('#', '.')]]
self.assertCountEqual(ans, results)
# Tests multiple tests of same sdk split across multiple jobs.
t0 = 'a.b#c[28]'
t1 = 'foo.bar#d[27]'
t2 = 'alice.bob#e[26]'
t3 = 'a.l#c[28]'
t4 = 'z.x#c[28]'
t5 = 'z.y#c[28]'
t6 = 'z.z#c[28]'
test_list = [t0, t1, t2, t3, t4, t5, t6]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
results = sorted(results)
t_ans = [x.replace('#', '.') for x in test_list]
ans = [[t_ans[0], t_ans[3], t_ans[4], t_ans[5], t_ans[6]], [t_ans[2]],
[t_ans[1]]]
self.assertCountEqual(ans, results)
# Tests having a class without an sdk
t0 = 'cow.moo#chicken'
t1 = 'a.b#c[28]'
t2 = 'foo.bar#d[27]'
t3 = 'alice.bob#e[26]'
t4 = 'a.l#c[28]'
t5 = 'z.x#c[28]'
t6 = 'z.y#c[28]'
t7 = 'z.moo#c[28]'
test_list = [t0, t1, t2, t3, t4, t5, t6, t7]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
t_ans = [x.replace('#', '.') for x in test_list]
self.assertEqual(len(results), 4)
ans = [[t_ans[0]], [t_ans[1], t_ans[4], t_ans[7], t_ans[5], t_ans[6]],
[t_ans[2]], [t_ans[3]]]
self.assertCountEqual(ans, results)
def testGroupTestsForShardWithSDK_ClassesPerJob(self):
"""Checks job splitting once a job reaches _MAX_TESTS_PER_JOB."""
# Tests grouping tests when _MAX_TESTS_PER_JOB is exceeded.
# All tests are from same class so should be in a single job.
# Shrink the job size so that splitting is exercised with only a few tests.
local_machine_junit_test_run._MAX_TESTS_PER_JOB = 3
t0 = 'plane.b17#bomb[28]'
t1 = 'plane.b17#gunner[28]'
t2 = 'plane.b17#pilot[28]'
t3 = 'plane.b17#copilot[28]'
t4 = 'plane.b17#radio[28]'
test_list = [t0, t1, t2, t3, t4]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
t_ans = [x.replace('#', '.') for x in test_list]
ans = [t_ans[0], t_ans[1], t_ans[2], t_ans[3], t_ans[4]]
found_ans = False
# Empty jobs are ignored; any non-empty job must hold all five b17 tests.
for r in results:
if len(r) > 0:
self.assertCountEqual(r, ans)
found_ans = True
self.assertTrue(found_ans)
# Tests grouping tests when _MAX_TESTS_PER_JOB is exceeded and classes are
# different.
t0 = 'plane.b17#bomb[28]'
t1 = 'plane.b17#gunner[28]'
t2 = 'plane.b17#pilot[28]'
t3 = 'plane.b24_liberator#copilot[28]'
t4 = 'plane.b24_liberator#radio[28]'
t5 = 'plane.b25_mitchel#doolittle[28]'
t6 = 'plane.b26_marauder#radio[28]'
t7 = 'plane.b36_peacemaker#nuclear[28]'
t8 = 'plane.b52_stratofortress#nuclear[30]'
test_list = [t0, t1, t2, t3, t4, t5, t6, t7, t8]
results = local_machine_junit_test_run.GroupTestsForShard(test_list)
t_ans = [x.replace('#', '.') for x in test_list]
checked_b17 = False
checked_b52 = False
for r in results:
# All of b17's tests must stay together in one job.
if t_ans[0] in r:
self.assertTrue(t_ans[1] in r)
self.assertTrue(t_ans[2] in r)
checked_b17 = True
# b52 is the only SDK 30 test, so it must be alone in its job.
if t_ans[8] in r:
self.assertEqual(1, len(r))
checked_b52 = True
continue
# Every job should have at least 1 test. Whole classes are added before the
# size limit is checked, so a job can exceed 3 tests; 5 (b17's three tests
# plus b24's two) is used as a loose upper bound.
self.assertTrue(len(r) >= 1 and len(r) <= 5)
# Expect 4 jobs: three for the eight SDK 28 tests and one for b52 (SDK 30).
self.assertTrue(all([checked_b17, checked_b52, len(results) == 4]))
if __name__ == '__main__':