diff --git a/CHANGELOG.md b/CHANGELOG.md index 0488d2f7..e48dfb02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Add details attribute to get_vts command. [#222](https://github.com/greenbone/ospd/pull/222) - Add [pontos](https://github.com/greenbone/pontos) as dev dependency for managing the version information in ospd [#254](https://github.com/greenbone/ospd/pull/254) +- Add more info about scan progress with progress attribute in get_scans cmd. [#266](https://github.com/greenbone/ospd/pull/266) ### Changes - Modify __init__() method and use new syntax for super(). [#186](https://github.com/greenbone/ospd/pull/186) @@ -30,11 +31,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). a bit different then `pipenv install`. It installs dev packages by default and also ospd in editable mode. This means after running poetry install ospd will directly be importable in the virtual python environment. [#252](https://github.com/greenbone/ospd/pull/252) +- Progress bar calculation does not take in account the dead hosts. [#266](https://github.com/greenbone/ospd/pull/266) ### Fixed - Fix stop scan. Wait for the scan process to be stopped before delete it from the process table. [#204](https://github.com/greenbone/ospd/pull/204) - Fix get_scanner_details(). [#210](https://github.com/greenbone/ospd/pull/210) +### Removed +- Remove support for resume task. [#266](https://github.com/greenbone/ospd/pull/266) + ## [2.0.1] (unreleased) ### Added diff --git a/doc/OSP.xml b/doc/OSP.xml index ef50ba7f..5f8efed0 100644 --- a/doc/OSP.xml +++ b/doc/OSP.xml @@ -208,7 +208,7 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. finished_hosts - One or many finished hosts to exclude when resuming a task. The list is comma-separated. Each entry can be an IP address, a CIDR notation, a hostname, a IP range. IPs can be v4 or v6. The listed hosts will be set as finished before starting the scan. Each wrapper must handle the finished hosts. + One or many finished hosts to exclude when the client resumes a task. The list is comma-separated. Each entry can be an IP address, a CIDR notation, a hostname, a IP range. IPs can be v4 or v6. The listed hosts will be set as finished before starting the scan. Each wrapper must handle the finished hosts. string @@ -440,6 +440,7 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. ID of a specific scan to get
Whether to return the full scan report
+ Whether to return a detailed progress information Whether to remove the fetched results Maximum number of results to fetch. Only considered if pop_results is enabled. Default = None, which means that all available results are returned
@@ -521,6 +522,11 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. Whether to get full scan reports boolean + + progress + Whether to return a detailed progress information + boolean + pop_results Whether to remove the fetched results @@ -624,6 +630,30 @@ Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + + Get a scan progress summary + + + + + + + + 0.0 + 68.00013854253257 + 26.0009697977279 + 0.0 + 38.800221668052096 + 1 + 249 + 0 + 254 + + + + + delete_scan diff --git a/ospd/command/command.py b/ospd/command/command.py index e6050097..1d1d6db3 100644 --- a/ospd/command/command.py +++ b/ospd/command/command.py @@ -412,22 +412,25 @@ def handle_xml(self, xml: Element) -> bytes: details = xml.get('details') pop_res = xml.get('pop_results') max_res = xml.get('max_results') + progress = xml.get('progress') if details and details == '0': details = False else: details = True - if pop_res and pop_res == '1': - pop_res = True - else: - pop_res = False + pop_res = pop_res and pop_res == '1' + if max_res: max_res = int(max_res) + progress = progress and progress == '1' + responses = [] if scan_id and scan_id in self._daemon.scan_collection.ids_iterator(): self._daemon.check_scan_process(scan_id) - scan = self._daemon.get_scan_xml(scan_id, details, pop_res, max_res) + scan = self._daemon.get_scan_xml( + scan_id, details, pop_res, max_res, progress + ) responses.append(scan) elif scan_id: text = "Failed to find scan '{0}'".format(scan_id) @@ -436,7 +439,7 @@ def handle_xml(self, xml: Element) -> bytes: for scan_id in self._daemon.scan_collection.ids_iterator(): self._daemon.check_scan_process(scan_id) scan = self._daemon.get_scan_xml( - scan_id, details, pop_res, max_res + scan_id, details, pop_res, max_res, progress ) responses.append(scan) diff --git a/ospd/ospd.py b/ospd/ospd.py index d9a450b2..d7b8f816 100644 --- a/ospd/ospd.py +++ b/ospd/ospd.py @@ -56,6 +56,7 @@ from ospd.xml import ( elements_as_text, get_result_xml, + get_progress_xml, get_elements_from_dict, ) @@ -82,6 +83,9 @@ }, } # type: Dict +PROGRESS_FINISHED = 100 +PROGRESS_DEAD_HOST = -1 + def _terminate_process_group(process: multiprocessing.Process) -> None: os.killpg(os.getpgid(process.pid), 15) @@ -410,7 +414,7 @@ def exec_scan(self, scan_id: str): def finish_scan(self, scan_id: str) -> None: """ Sets a scan as finished. """ - self.set_scan_progress(scan_id, 100) + self.scan_collection.set_progress(scan_id, PROGRESS_FINISHED) self.set_scan_status(scan_id, ScanStatus.FINISHED) logger.info("%s: Scan finished.", scan_id) @@ -505,29 +509,14 @@ def calculate_progress(self, scan_id: str) -> float: return self.scan_collection.calculate_target_progress(scan_id) - def process_exclude_hosts(self, scan_id: str, exclude_hosts: str) -> None: - """ Process the exclude hosts before launching the scans.""" - - exc_hosts_list = '' - if not exclude_hosts: - return - exc_hosts_list = target_str_to_list(exclude_hosts) - self.remove_scan_hosts_from_target_progress(scan_id, exc_hosts_list) - def process_finished_hosts(self, scan_id: str, finished_hosts: str) -> None: - """ Process the finished hosts before launching the scans. - Set finished hosts as finished with 100% to calculate - the scan progress.""" + """ Process the finished hosts before launching the scans.""" - exc_hosts_list = '' if not finished_hosts: return - exc_hosts_list = target_str_to_list(finished_hosts) - - for host in exc_hosts_list: - self.set_scan_host_finished(scan_id, finished_hosts=host) - self.set_scan_host_progress(scan_id, host=host, progress=100) + exc_finished_hosts_list = target_str_to_list(finished_hosts) + self.scan_collection.set_host_finished(scan_id, exc_finished_hosts_list) def start_scan(self, scan_id: str, target: Dict) -> None: """ Starts the scan with scan_id. """ @@ -538,7 +527,6 @@ def start_scan(self, scan_id: str, target: Dict) -> None: logger.info("%s: Scan started.", scan_id) - self.process_exclude_hosts(scan_id, target.get('exclude_hosts')) self.process_finished_hosts(scan_id, target.get('finished_hosts')) try: @@ -584,27 +572,34 @@ def handle_timeout(self, scan_id: str, host: str) -> None: value="{0} exec timeout.".format(self.get_scanner_name()), ) - def remove_scan_hosts_from_target_progress( - self, scan_id: str, exc_hosts_list: List - ) -> None: - """ Remove a list of hosts from the main scan progress table.""" - self.scan_collection.remove_hosts_from_target_progress( - scan_id, exc_hosts_list - ) - - def set_scan_host_finished( + def sort_host_finished( self, scan_id: str, finished_hosts: Union[List[str], str], ) -> None: - """ Add the host in a list of finished hosts """ + """ Check if the finished host in the list was alive or dead + and update the corresponding alive_count or dead_count. """ if isinstance(finished_hosts, str): finished_hosts = [finished_hosts] - self.scan_collection.set_host_finished(scan_id, finished_hosts) + alive_hosts = [] + dead_hosts = [] + + current_hosts = self.scan_collection.get_current_target_progress( + scan_id + ) + for finished_host in finished_hosts: + progress = current_hosts.get(finished_host) + if progress == PROGRESS_FINISHED: + alive_hosts.append(finished_host) + if progress == PROGRESS_DEAD_HOST: + dead_hosts.append(finished_host) + + self.scan_collection.set_host_dead(scan_id, dead_hosts) - def set_scan_progress(self, scan_id: str, progress: int) -> None: - """ Sets scan_id scan's progress which is a number - between 0 and 100. """ - self.scan_collection.set_progress(scan_id, progress) + self.scan_collection.set_host_finished(scan_id, alive_hosts) + + self.scan_collection.remove_hosts_from_target_progress( + scan_id, finished_hosts + ) def set_scan_progress_batch( self, scan_id: str, host_progress: Dict[str, int] @@ -612,7 +607,7 @@ def set_scan_progress_batch( self.scan_collection.set_host_progress(scan_id, host_progress) scan_progress = self.calculate_progress(scan_id) - self.set_scan_progress(scan_id, scan_progress) + self.scan_collection.set_progress(scan_id, scan_progress) def set_scan_host_progress( self, scan_id: str, host: str = None, progress: int = None, @@ -621,9 +616,6 @@ def set_scan_host_progress( Each time a host progress is updated, the scan progress is updated too. """ - if host and progress < 0 or progress > 100: - return - host_progress = {host: progress} self.set_scan_progress_batch(scan_id, host_progress) @@ -711,6 +703,32 @@ def get_scan_results_xml( logger.debug('Returning %d results', len(results)) return results + def _get_scan_progress_xml(self, scan_id: str): + """ Gets scan_id scan's progress in XML format. + + @return: String of scan progress in xml. + """ + current_progress = dict() + + current_progress[ + 'current_hosts' + ] = self.scan_collection.get_current_target_progress(scan_id) + current_progress['overall'] = self.scan_collection.get_progress(scan_id) + current_progress['count_alive'] = self.scan_collection.get_count_alive( + scan_id + ) + current_progress['count_dead'] = self.scan_collection.get_count_dead( + scan_id + ) + current_progress[ + 'count_excluded' + ] = self.scan_collection.simplify_exclude_host_count(scan_id) + current_progress['count_total'] = self.scan_collection.get_host_count( + scan_id + ) + + return get_progress_xml(current_progress) + @deprecated( version="20.8", reason="Please use ospd.xml.get_elements_from_dict instead.", @@ -730,6 +748,7 @@ def get_scan_xml( detailed: bool = True, pop_res: bool = False, max_res: int = 0, + progress: bool = False, ): """ Gets scan in XML format. @@ -757,6 +776,9 @@ def get_scan_xml( response.append( self.get_scan_results_xml(scan_id, pop_res, max_res) ) + if progress: + response.append(self._get_scan_progress_xml(scan_id)) + return response @staticmethod @@ -1178,23 +1200,17 @@ def create_scan( @target: Target to scan. @options: Miscellaneous scan options. - @return: New scan's ID. None if the scan_id already exists and the - scan status is RUNNING or FINISHED. + @return: New scan's ID. None if the scan_id already exists. """ status = None scan_exists = self.scan_exists(scan_id) if scan_id and scan_exists: status = self.get_scan_status(scan_id) - - if scan_exists and status == ScanStatus.STOPPED: - logger.info("Scan %s exists. Resuming scan.", scan_id) - elif scan_exists and ( - status == ScanStatus.RUNNING or status == ScanStatus.FINISHED - ): logger.info( "Scan %s exists with status %s.", scan_id, status.name.lower() ) return + return self.scan_collection.create_scan( scan_id, targets, options, vt_selection ) @@ -1238,7 +1254,7 @@ def check_scan_process(self, scan_id: str) -> None: scan_process = self.scan_processes[scan_id] progress = self.get_scan_progress(scan_id) - if progress < 100 and not scan_process.is_alive(): + if progress < PROGRESS_FINISHED and not scan_process.is_alive(): if not self.get_scan_status(scan_id) == ScanStatus.STOPPED: self.set_scan_status(scan_id, ScanStatus.STOPPED) self.add_scan_error( @@ -1247,7 +1263,7 @@ def check_scan_process(self, scan_id: str) -> None: logger.info("%s: Scan stopped with errors.", scan_id) - elif progress == 100: + elif progress == PROGRESS_FINISHED: scan_process.join(0) def get_scan_progress(self, scan_id: str): @@ -1281,14 +1297,6 @@ def get_scan_vts(self, scan_id: str) -> Dict: """ Gives a scan's vts. """ return self.scan_collection.get_vts(scan_id) - def get_scan_unfinished_hosts(self, scan_id: str) -> List: - """ Get a list of unfinished hosts.""" - return self.scan_collection.get_hosts_unfinished(scan_id) - - def get_scan_finished_hosts(self, scan_id: str) -> List: - """ Get a list of unfinished hosts.""" - return self.scan_collection.get_hosts_finished(scan_id) - def get_scan_start_time(self, scan_id: str) -> str: """ Gives a scan's start time. """ return self.scan_collection.get_start_time(scan_id) diff --git a/ospd/scan.py b/ospd/scan.py index 0d4c5d21..2dedf1cf 100644 --- a/ospd/scan.py +++ b/ospd/scan.py @@ -148,29 +148,26 @@ def set_host_progress( self.scans_table[scan_id]['target_progress'] = host_progresses def set_host_finished(self, scan_id: str, hosts: List[str]) -> None: - """ Add the host in a list of finished hosts """ - finished_hosts = self.scans_table[scan_id].get('finished_hosts') - finished_hosts.extend(hosts) - uniq_set_of_hosts = set(finished_hosts) + """ Increase the amount of finished hosts which were alive.""" - self.scans_table[scan_id]['finished_hosts'] = list(uniq_set_of_hosts) - - def get_hosts_unfinished(self, scan_id: str) -> List[Any]: - """ Get a list of unfinished hosts.""" - - unfinished_hosts = target_str_to_list(self.get_host_list(scan_id)) - - finished_hosts = self.get_hosts_finished(scan_id) + total_finished = len(hosts) + count_alive = ( + self.scans_table[scan_id].get('count_alive') + total_finished + ) + self.scans_table[scan_id]['count_alive'] = count_alive - for host in finished_hosts: - unfinished_hosts.remove(host) + def set_host_dead(self, scan_id: str, hosts: List[str]) -> None: + """ Increase the amount of dead hosts. """ - return unfinished_hosts + total_dead = len(hosts) + count_dead = self.scans_table[scan_id].get('count_dead') + total_dead + self.scans_table[scan_id]['count_dead'] = count_dead - def get_hosts_finished(self, scan_id: str) -> List: - """ Get a list of finished hosts.""" + def set_amount_dead_hosts(self, scan_id: str, total_dead: int) -> None: + """ Increase the amount of dead hosts. """ - return self.scans_table[scan_id].get('finished_hosts') + count_dead = self.scans_table[scan_id].get('count_dead') + total_dead + self.scans_table[scan_id]['count_dead'] = count_dead def results_iterator( self, scan_id: str, pop_res: bool = False, max_res: int = None @@ -199,50 +196,6 @@ def ids_iterator(self) -> Iterator[str]: return iter(self.scans_table.keys()) - def remove_single_result( - self, scan_id: str, result: Dict[str, str] - ) -> None: - """Removes a single result from the result list in scan_table. - - Parameters: - scan_id (uuid): Scan ID to identify the scan process to be resumed. - result (dict): The result to be removed from the results list. - """ - results = self.scans_table[scan_id]['results'] - results.remove(result) - self.scans_table[scan_id]['results'] = results - - def del_results_for_stopped_hosts(self, scan_id: str) -> None: - """ Remove results from the result table for those host - """ - unfinished_hosts = self.get_hosts_unfinished(scan_id) - for result in self.results_iterator( - scan_id, pop_res=False, max_res=None - ): - if result['host'] in unfinished_hosts: - self.remove_single_result(scan_id, result) - - def resume_scan(self, scan_id: str, options: Optional[Dict]) -> str: - """ Reset the scan status in the scan_table to INIT. - Also, overwrite the options, because a resume task cmd - can add some new option. E.g. exclude hosts list. - Parameters: - scan_id (uuid): Scan ID to identify the scan process to be resumed. - options (dict): Options for the scan to be resumed. This options - are not added to the already existent ones. - The old ones are removed - - Return: - Scan ID which identifies the current scan. - """ - self.scans_table[scan_id]['status'] = ScanStatus.INIT - if options: - self.scans_table[scan_id]['options'] = options - - self.del_results_for_stopped_hosts(scan_id) - - return scan_id - def create_scan( self, scan_id: str = '', @@ -258,25 +211,15 @@ def create_scan( if self.data_manager is None: self.data_manager = multiprocessing.Manager() - # Check if it is possible to resume task. To avoid to resume, the - # scan must be deleted from the scans_table. - if ( - scan_id - and self.id_exists(scan_id) - and (self.get_status(scan_id) == ScanStatus.STOPPED) - ): - self.scans_table[scan_id]['end_time'] = 0 - - return self.resume_scan(scan_id, options) - if not options: options = dict() scan_info = self.data_manager.dict() # type: Dict scan_info['results'] = list() - scan_info['finished_hosts'] = list() scan_info['progress'] = 0 scan_info['target_progress'] = dict() + scan_info['count_alive'] = 0 + scan_info['count_dead'] = 0 scan_info['target'] = target scan_info['vts'] = vts scan_info['options'] = options @@ -318,11 +261,29 @@ def get_progress(self, scan_id: str) -> int: return self.scans_table[scan_id]['progress'] - def simplify_exclude_host_list(self, scan_id: str) -> List[Any]: + def get_count_dead(self, scan_id: str) -> int: + """ Get a scan's current dead host count. """ + + return self.scans_table[scan_id]['count_dead'] + + def get_count_alive(self, scan_id: str) -> int: + """ Get a scan's current dead host count. """ + + return self.scans_table[scan_id]['count_alive'] + + def get_current_target_progress(self, scan_id: str) -> Dict[str, int]: + """ Get a scan's current dead host count. """ + + return self.scans_table[scan_id]['target_progress'] + + def simplify_exclude_host_count(self, scan_id: str) -> int: """ Remove from exclude_hosts the received hosts in the finished_hosts list sent by the client. The finished hosts are sent also as exclude hosts for backward compatibility purposses. + + Return: + Count of excluded host. """ exc_hosts_list = target_str_to_list(self.get_exclude_hosts(scan_id)) @@ -336,22 +297,22 @@ def simplify_exclude_host_list(self, scan_id: str) -> List[Any]: if finished in exc_hosts_list: exc_hosts_list.remove(finished) - return exc_hosts_list + return len(exc_hosts_list) if exc_hosts_list else 0 def calculate_target_progress(self, scan_id: str) -> float: """ Get a target's current progress value. The value is calculated with the progress of each single host in the target.""" - host = self.get_host_list(scan_id) - total_hosts = len(target_str_to_list(host)) - exc_hosts_list = self.simplify_exclude_host_list(scan_id) - exc_hosts = len(exc_hosts_list) if exc_hosts_list else 0 + total_hosts = self.get_host_count(scan_id) + exc_hosts = self.simplify_exclude_host_count(scan_id) + count_alive = self.get_count_alive(scan_id) + count_dead = self.get_count_dead(scan_id) host_progresses = self.scans_table[scan_id].get('target_progress') try: - t_prog = sum(host_progresses.values()) / ( - total_hosts - exc_hosts + t_prog = (sum(host_progresses.values()) + 100 * count_alive) / ( + total_hosts - exc_hosts - count_dead ) # type: float except ZeroDivisionError: LOGGER.error( @@ -377,6 +338,13 @@ def get_host_list(self, scan_id: str) -> Dict: return self.scans_table[scan_id]['target'].get('hosts') + def get_host_count(self, scan_id: str) -> int: + """ Get total host count in the target. """ + host = self.get_host_list(scan_id) + total_hosts = len(target_str_to_list(host)) + + return total_hosts + def get_ports(self, scan_id: str): """ Get a scan's ports list. """ diff --git a/ospd/xml.py b/ospd/xml.py index b62d9c75..793b02c0 100644 --- a/ospd/xml.py +++ b/ospd/xml.py @@ -99,6 +99,32 @@ def get_result_xml(result): return result_xml +def get_progress_xml(progress: Dict[str, int]): + """ Formats a scan progress to XML format. + + Arguments: + progress (dict): Dictionary with a scan progress. + + Return: + Progress as xml element object. + """ + + progress_xml = Element('progress') + for progress_item, value in progress.items(): + elem = None + if progress_item == 'current_hosts': + for host, h_progress in value.items(): + elem = Element('host') + elem.set('name', host) + elem.text = str(h_progress) + progress_xml.append(elem) + else: + elem = Element(progress_item) + elem.text = str(value) + progress_xml.append(elem) + return progress_xml + + def simple_response_str( command: str, status: int, diff --git a/tests/test_scan_and_result.py b/tests/test_scan_and_result.py index dff9ccbe..0733776d 100644 --- a/tests/test_scan_and_result.py +++ b/tests/test_scan_and_result.py @@ -792,35 +792,29 @@ def test_scan_get_target_options(self): target_options = daemon.get_scan_target_options(scan_id) self.assertEqual(target_options, {'alive_test': '0'}) - def test_scan_get_finished_hosts(self): + def test_progress(self): daemon = DummyWrapper([]) + fs = FakeStream() daemon.handle_command( - '' - '' - '' + '' + '' '' - '192.168.10.20-25' - '80,443' - '192.168.10.23-24' - '' - '' - '192.168.0.0/24' - '22' - '' + 'localhost1, localhost2' + '22' + '' '', fs, ) response = fs.get_response() scan_id = response.findtext('id') - time.sleep(1) - finished = daemon.get_scan_finished_hosts(scan_id) - for host in ['192.168.10.23', '192.168.10.24']: - self.assertIn(host, finished) - self.assertEqual(len(finished), 2) + daemon.set_scan_host_progress(scan_id, 'localhost1', 75) + daemon.set_scan_host_progress(scan_id, 'localhost2', 25) - def test_progress(self): + self.assertEqual(daemon.calculate_progress(scan_id), 50) + + def test_sort_host_finished(self): daemon = DummyWrapper([]) fs = FakeStream() @@ -828,7 +822,7 @@ def test_progress(self): '' '' '' - 'localhost1, localhost2' + 'localhost1, localhost2, localhost3, localhost4' '22' '' '', @@ -837,11 +831,62 @@ def test_progress(self): response = fs.get_response() scan_id = response.findtext('id') + daemon.set_scan_host_progress(scan_id, 'localhost3', -1) + daemon.set_scan_host_progress(scan_id, 'localhost1', 75) + daemon.set_scan_host_progress(scan_id, 'localhost4', 100) + daemon.set_scan_host_progress(scan_id, 'localhost2', 25) + + daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) + + rounded_progress = int(daemon.calculate_progress(scan_id)) + self.assertEqual(rounded_progress, 66) + + def test_get_scan_progress_xml(self): + daemon = DummyWrapper([]) + + fs = FakeStream() + daemon.handle_command( + '' + '' + '' + 'localhost1, localhost2, localhost3, localhost4' + '22' + '' + '', + fs, + ) + response = fs.get_response() + scan_id = response.findtext('id') + + daemon.set_scan_host_progress(scan_id, 'localhost3', -1) + daemon.set_scan_host_progress(scan_id, 'localhost4', 100) + daemon.sort_host_finished(scan_id, ['localhost3', 'localhost4']) daemon.set_scan_host_progress(scan_id, 'localhost1', 75) daemon.set_scan_host_progress(scan_id, 'localhost2', 25) - self.assertEqual(daemon.calculate_progress(scan_id), 50) + fs = FakeStream() + daemon.handle_command( + '', fs, + ) + response = fs.get_response() + + progress = response.find('scan/progress') + + overall = float(progress.findtext('overall')) + self.assertEqual(int(overall), 66) + + count_alive = progress.findtext('count_alive') + self.assertEqual(count_alive, '1') + + count_dead = progress.findtext('count_dead') + self.assertEqual(count_dead, '1') + + current_hosts = progress.findall('host') + self.assertEqual(len(current_hosts), 2) + + count_excluded = progress.findtext('count_excluded') + self.assertEqual(count_excluded, '0') def test_set_get_vts_version(self): daemon = DummyWrapper([]) @@ -856,17 +901,8 @@ def test_set_get_vts_version_error(self): @patch("ospd.ospd.os") @patch("ospd.command.command.create_process") - def test_resume_task(self, mock_create_process, _mock_os): - daemon = DummyWrapper( - [ - Result( - 'host-detail', host='localhost', value='Some Host Detail' - ), - Result( - 'host-detail', host='localhost', value='Some Host Detail2' - ), - ] - ) + def test_scan_exists(self, mock_create_process, _mock_os): + daemon = DummyWrapper([]) fp = FakeStartProcess() mock_create_process.side_effect = fp @@ -888,59 +924,33 @@ def test_resume_task(self, mock_create_process, _mock_os): ) response = fs.get_response() scan_id = response.findtext('id') - self.assertIsNotNone(scan_id) + status = response.get('status_text') + self.assertEqual(status, 'OK') + assert_called(mock_create_process) assert_called(mock_process.start) - fs = FakeStream() daemon.handle_command('' % scan_id, fs) fs = FakeStream() - daemon.handle_command( - '' % scan_id, fs - ) - response = fs.get_response() - - result = response.findall('scan/results/result') - self.assertEqual(len(result), 2) - - # Resume the task cmd = ( - '' + '' '' - ''.format(scan_id) - ) - fs = FakeStream() - daemon.handle_command(cmd, fs) - response = fs.get_response() - - # Check unfinished host - self.assertEqual(response.findtext('id'), scan_id) - self.assertEqual( - daemon.get_scan_unfinished_hosts(scan_id), ['localhost'] - ) - - # Finished the host and check unfinished again. - daemon.set_scan_host_finished(scan_id, "localhost") - self.assertEqual(len(daemon.get_scan_unfinished_hosts(scan_id)), 0) - - # Check finished hosts - self.assertEqual( - daemon.scan_collection.get_hosts_finished(scan_id), ['localhost'] + '' + 'localhost' + '22' + '' + '' ) - # Check if the result was removed. - fs = FakeStream() daemon.handle_command( - '' % scan_id, fs + cmd, fs, ) response = fs.get_response() - result = response.findall('scan/results/result') - - # current the response still contains the results - # self.assertEqual(len(result), 0) + status = response.get('status_text') + self.assertEqual(status, 'Continue') def test_result_order(self): daemon = DummyWrapper([])