# Copyright (C) 2017 Intel Corporation
# Released under the MIT license (see COPYING.MIT)

import threading
import multiprocessing

from unittest.suite import TestSuite
from oeqa.core.loader import OETestLoader
from oeqa.core.runner import OEStreamLogger, OETestResult


class OETestLoaderThreaded(OETestLoader):
    """Test loader that splits the discovered tests into per-thread suites.

    Some cases must share a suite:
    - cases from the same module;
    - cases linked through the ``depends`` registry, when the dependency has
      already been placed.
    Any other case goes to the currently smallest suite.
    """

    def __init__(self, tc, module_paths, modules, tests, modules_required,
                 filters, process_num=0, *args, **kwargs):
        super(OETestLoaderThreaded, self).__init__(tc, module_paths, modules,
                tests, modules_required, filters, *args, **kwargs)
        # Number of suites to build; a value <= 0 is replaced in discover()
        # by min(cpu_count, number of top-level discovered entries).
        self.process_num = process_num

    def discover(self):
        """Discover tests and distribute them over ``process_num`` suites.

        :returns: list of the non-empty suites (instances of
                  ``self.suiteClass``).
        """
        suite = super(OETestLoaderThreaded, self).discover()

        if self.process_num <= 0:
            self.process_num = min(multiprocessing.cpu_count(),
                    len(suite._tests))

        suites = [self.suiteClass() for _ in range(self.process_num)]

        # Read the dependency map once; nothing below mutates the registry.
        depends = {}
        if 'depends' in self.tc._registry:
            depends = self.tc._registry['depends']

        def _search_for_module_idx(suites, case):
            """Return the index of the suite that already holds a case from
            ``case``'s module, or -1 if there is none.

            Cases in the same module need to run in the same thread because
            PyUnit keeps track of setUp{Module,Class} and
            tearDown{Module,Class}.
            """
            for idx, candidate in enumerate(suites):
                for placed in candidate._tests:
                    if case.__module__ == placed.__module__:
                        return idx
            return -1

        def _search_for_depend_idx(suites, case_depends):
            """Return the index of the suite that holds any of the test ids in
            ``case_depends``, or -1 if there is none.

            Dependent cases need to run in the same thread because the OEQA
            framework inspects the state of the dependency to decide whether
            to skip.
            """
            for idx, candidate in enumerate(suites):
                for placed in candidate._tests:
                    if placed.id() in case_depends:
                        return idx
            return -1

        def _get_best_idx(suites):
            """Return the index of the suite with the fewest tests (first one
            on ties)."""
            sizes = [len(candidate._tests) for candidate in suites]
            return sizes.index(min(sizes))

        def _fill_suites(suite):
            """Recursively place every leaf case of ``suite`` into ``suites``."""
            for case in suite:
                if isinstance(case, TestSuite):
                    _fill_suites(case)
                    continue

                idx = _search_for_module_idx(suites, case)
                if idx == -1 and case.id() in depends:
                    idx = _search_for_depend_idx(suites, depends[case.id()])
                if idx == -1:
                    idx = _get_best_idx(suites)
                suites[idx].addTest(case)

        _fill_suites(suite)

        # Drop suites that received no tests.
        return [s for s in suites if len(s._tests) > 0]


class OEStreamLoggerThreaded(OEStreamLogger):
    """Stream logger that buffers written text per thread.

    finish() emits the calling thread's buffer as one labelled block.
    """

    # Class-level, so shared by all instances. It serialises finish() so
    # blocks from different threads are not interleaved.
    _lock = threading.Lock()
    # Class-level, so shared by all instances: thread ident -> buffered text.
    buffers = {}

    def write(self, msg):
        """Append ``msg`` to the calling thread's buffer (falsy msg is ignored,
        but the buffer entry is still created)."""
        tid = threading.get_ident()
        if tid not in self.buffers:
            self.buffers[tid] = ""
        if msg:
            self.buffers[tid] += msg

    def finish(self):
        """Log the calling thread's buffered output, one logger line per line.

        Behaviour:
        - The buffer is removed once taken, so a later thread that receives a
          recycled ident does not re-log stale output.
        - A thread that never wrote anything logs an empty block instead of
          raising.
        """
        tid = threading.get_ident()
        buffered = self.buffers.pop(tid, "")
        # 'with' guarantees the lock is released even if logging raises;
        # otherwise every other thread's finish() would deadlock.
        with self._lock:
            self.logger.info('THREAD: %d' % tid)
            self.logger.info('-' * 70)
            for line in buffered.split('\n'):
                self.logger.info(line)


class OETestResultThreadedInternal(OETestResult):
    """Per-thread result that publishes its lists into ``tc._results[tid]``."""

    def _tc_map_results(self):
        tid = threading.get_ident()

        # Reuse the thread's existing entry if there is one rather than
        # replacing the dict, so keys already stored there are preserved.
        # NOTE(review): the four lists below are still rebound to this result
        # object's lists; if PyUnit creates a new result object per test
        # module within one thread, earlier modules' lists are replaced —
        # confirm against how the runner invokes this.
        if tid not in self.tc._results:
            self.tc._results[tid] = {}

        entry = self.tc._results[tid]
        entry['failures'] = self.failures
        entry['errors'] = self.errors
        entry['skipped'] = self.skipped
        entry['expectedFailures'] = self.expectedFailures


class OETestResultThreaded(object):
    """Collects the result objects produced by each worker thread and exposes
    an aggregated view of them."""

    _lock = threading.Lock()

    # Names of the per-thread result lists that _fill_tc_results aggregates.
    _FIELDS = ('failures', 'errors', 'skipped', 'expectedFailures')

    def __init__(self, tc):
        self.tc = tc
        # Per-instance rather than class-level: a class-level dict would be
        # shared by every instance, leaking results between runs.
        # thread ident -> {'result', 'run_start_time', 'run_end_time'}
        self._results = {}

    def _fill_tc_results(self):
        """Concatenate every thread's result lists into
        ``tc._results[field]``.

        Per-thread entries are keyed by thread ident. The aggregated lists
        live in the same dict under the field names, so those keys are
        skipped. The lists are rebuilt from scratch, which makes repeated
        calls safe.
        """
        per_thread = [entry for key, entry in list(self.tc._results.items())
                      if key not in self._FIELDS]

        for field in self._FIELDS:
            aggregated = []
            for entry in per_thread:
                aggregated.extend(entry[field])
            self.tc._results[field] = aggregated

    def addResult(self, result, run_start_time, run_end_time):
        """Record the calling thread's result object and its run timestamps."""
        tid = threading.get_ident()

        with self._lock:
            self._results[tid] = {
                'result': result,
                'run_start_time': run_start_time,
                'run_end_time': run_end_time,
            }

    def _result_objects(self):
        """Return a snapshot of the recorded result objects, in insertion
        order.

        The snapshot is taken under the same lock addResult uses.
        """
        with self._lock:
            return [entry['result'] for entry in self._results.values()]

    def wasSuccessful(self):
        """True if every recorded thread result reports success."""
        return all(result.wasSuccessful() for result in self._result_objects())

    def stop(self):
        """Ask every recorded thread result to stop."""
        for result in self._result_objects():
            result.stop()

    def logSummary(self, component, context_msg=''):
        """Log the overall test count, elapsed time and pass/fail verdict."""
        elapsed_time = (self.tc._run_end_time - self.tc._run_start_time)

        self.tc.logger.info("SUMMARY:")
        self.tc.logger.info("%s (%s) - Ran %d tests in %.3fs" % (component,
            context_msg, len(self.tc._registry['cases']), elapsed_time))
        if self.wasSuccessful():
            msg = "%s - OK - All required tests passed" % component
        else:
            msg = "%s - FAIL - Required tests failed" % component
        self.tc.logger.info(msg)

    def logDetails(self):
        """Log per-test details using the first recorded result object.

        Does nothing if no results were recorded.
        NOTE(review): only one result object is asked to log. This presumably
        relies on OETestResult.logDetails reading the aggregated data in
        ``self.tc._results`` — confirm.
        """
        results = self._result_objects()
        if not results:
            return
        results[0].logDetails()