Skip to content

Commit

Permalink
Merge pull request greenbone#332 from jjnicola/update-total
Browse files Browse the repository at this point in the history
Allow the scanner to update total count of hosts.
  • Loading branch information
jjnicola authored Oct 12, 2020
2 parents 9f637f0 + a84d0e2 commit 6650020
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [20.8.2] (unreleased)

### Added
- Allow the scanner to update total count of hosts. [#332](https://github.com/greenbone/ospd/pull/332)

### Fixed
- Fix OSP version. [#326](https://github.com/greenbone/ospd/pull/326)

Expand Down
7 changes: 6 additions & 1 deletion ospd/ospd.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _get_scan_progress_xml(self, scan_id: str):
current_progress[
'count_excluded'
] = self.scan_collection.simplify_exclude_host_count(scan_id)
current_progress['count_total'] = self.scan_collection.get_host_count(
current_progress['count_total'] = self.scan_collection.get_count_total(
scan_id
)

Expand Down Expand Up @@ -1392,6 +1392,11 @@ def set_scan_option(self, scan_id: str, name: str, value: Any) -> None:
""" Sets a scan's option to a provided value. """
return self.scan_collection.set_option(scan_id, name, value)

def set_scan_total_hosts(self, scan_id: str, count_total: int) -> None:
"""Sets a scan's total hosts. Allow the scanner to update
the total count of host to be scanned."""
self.scan_collection.update_count_total(scan_id, count_total)

def clean_forgotten_scans(self) -> None:
"""Check for old stopped or finished scans which have not been
deleted and delete them if the are older than the set value."""
Expand Down
20 changes: 18 additions & 2 deletions ospd/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def unpickle_scan_info(self, scan_id: str) -> None:
scan_info['target_progress'] = dict()
scan_info['count_alive'] = 0
scan_info['count_dead'] = 0
scan_info['count_total'] = 0
scan_info['target'] = unpickled_scan_info.pop('target')
scan_info['vts'] = unpickled_scan_info.pop('vts')
scan_info['options'] = unpickled_scan_info.pop('options')
Expand Down Expand Up @@ -360,10 +361,25 @@ def get_count_dead(self, scan_id: str) -> int:
return self.scans_table[scan_id]['count_dead']

def get_count_alive(self, scan_id: str) -> int:
""" Get a scan's current dead host count. """
""" Get a scan's current alive host count. """

return self.scans_table[scan_id]['count_alive']

def update_count_total(self, scan_id: str, count_total: int) -> int:
""" Sets a scan's total hosts."""

self.scans_table[scan_id]['count_total'] = count_total

def get_count_total(self, scan_id: str) -> int:
""" Get a scan's total host count. """

count_total = self.scans_table[scan_id]['count_total']
if not count_total:
count_total = self.get_host_count(scan_id)
self.update_count_total(scan_id, count_total)

return count_total

def get_current_target_progress(self, scan_id: str) -> Dict[str, int]:
""" Get a scan's current hosts progress """
return self.scans_table[scan_id]['target_progress']
Expand Down Expand Up @@ -396,7 +412,7 @@ def calculate_target_progress(self, scan_id: str) -> int:
The value is calculated with the progress of each single host
in the target."""

total_hosts = self.get_host_count(scan_id)
total_hosts = self.get_count_total(scan_id)
exc_hosts = self.simplify_exclude_host_count(scan_id)
count_alive = self.get_count_alive(scan_id)
count_dead = self.get_count_dead(scan_id)
Expand Down
40 changes: 35 additions & 5 deletions tests/test_scan_and_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def test_get_vts_filter_negative(self):
)
fs = FakeStream()
self.daemon.handle_command(
'<get_vts filter="modification_time&lt;19000203"></get_vts>', fs,
'<get_vts filter="modification_time&lt;19000203"></get_vts>',
fs,
)
response = fs.get_response()

Expand Down Expand Up @@ -741,7 +742,8 @@ def test_get_scan_results_clean(self):

fs = FakeStream()
self.daemon.handle_command(
'<get_scans scan_id="%s" pop_results="1"/>' % scan_id, fs,
'<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
fs,
)

res_len = len(
Expand Down Expand Up @@ -771,7 +773,8 @@ def test_get_scan_results_restore(self):

fs = FakeStream(return_value=False)
self.daemon.handle_command(
'<get_scans scan_id="%s" pop_results="1"/>' % scan_id, fs,
'<get_scans scan_id="%s" pop_results="1"/>' % scan_id,
fs,
)

res_len = len(
Expand Down Expand Up @@ -1060,6 +1063,31 @@ def test_get_scan_without_scanid(self):
fs,
)

def test_set_scan_total_hosts(self):

fs = FakeStream()
self.daemon.handle_command(
'<start_scan parallel="2">'
'<scanner_params />'
'<targets><target>'
'<hosts>localhost1, localhost2, localhost3, localhost4</hosts>'
'<ports>22</ports>'
'</target></targets>'
'</start_scan>',
fs,
)
self.daemon.start_queued_scans()

response = fs.get_response()
scan_id = response.findtext('id')

count = self.daemon.scan_collection.get_count_total(scan_id)
self.assertEqual(count, 4)

self.daemon.set_scan_total_hosts(scan_id, 3)
count = self.daemon.scan_collection.get_count_total(scan_id)
self.assertEqual(count, 3)

def test_get_scan_progress_xml(self):

fs = FakeStream()
Expand Down Expand Up @@ -1087,7 +1115,8 @@ def test_get_scan_progress_xml(self):

fs = FakeStream()
self.daemon.handle_command(
'<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id, fs,
'<get_scans scan_id="%s" details="0" progress="1"/>' % scan_id,
fs,
)
response = fs.get_response()

Expand Down Expand Up @@ -1164,7 +1193,8 @@ def test_scan_exists(self, mock_create_process, _mock_os):
)

self.daemon.handle_command(
cmd, fs,
cmd,
fs,
)
self.daemon.start_queued_scans()

Expand Down

0 comments on commit 6650020

Please sign in to comment.