diff --git a/CHANGELOG.md b/CHANGELOG.md index c6115133..f3606c2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added - Add solution method to solution of vt object. [#131](https://github.com/greenbone/ospd-openvas/pull/131) - Add set_nvticache_str(). [#150](https://github.com/greenbone/ospd-openvas/pull/150) -- Add typing to nvticache.py. [#161](https://github.com/greenbone/ospd-openvas/pull/161) +- Add typing to nvticache.py and db.py. [#161](https://github.com/greenbone/ospd-openvas/pull/161)[#162](https://github.com/greenbone/ospd-openvas/pull/162) ## [1.0.1] (unreleased) diff --git a/ospd_openvas/db.py b/ospd_openvas/db.py index 5f50af19..55cc2f02 100644 --- a/ospd_openvas/db.py +++ b/ospd_openvas/db.py @@ -23,6 +23,8 @@ import sys import time +from typing import List, NewType, Optional + import redis from ospd.errors import RequiredArgument @@ -54,6 +56,9 @@ logger = logging.getLogger(__name__) +# Types +RedisCtx = NewType('RedisCtx', redis.Redis) + class OpenvasDB(object): """ Class to connect to redis, to perform queries, and to move @@ -71,10 +76,10 @@ def __init__(self): self.rediscontext = None @staticmethod - def _parse_openvas_db_address(result): + def _parse_openvas_db_address(result: bytes) -> str: """ Return the path to the redis socket. Arguments: - result (bytes) Output of `openvas -s` + result: Output of `openvas -s` Return redis unix socket path. """ path = None @@ -124,10 +129,10 @@ def max_db_index(self): 'Redis Error: Not possible ' 'to get max_dbindex.' ) - def set_redisctx(self, ctx): + def set_redisctx(self, ctx: RedisCtx): """ Set the current rediscontext. Arguments: - ctx (object): Redis context to be set as default. + ctx: Redis context to be set as default. """ if not ctx: raise RequiredArgument('set_redisctx', 'ctx') @@ -138,13 +143,13 @@ def db_init(self): self.get_db_connection() self.max_db_index() - def try_database_index(self, ctx, kb): + def try_database_index(self, ctx: RedisCtx, kb: int) -> bool: """ Check if a redis kb is already in use. If not, set it as in use and return. Arguments: - ctx (object): Redis object connected to the kb with the + ctx: Redis object connected to the kb with the DBINDEX_NAME key. - kb (int): Kb number intended to be used. + kb: Kb number intended to be used. Return True if it is possible to use the kb. False if the given kb number is already in use. @@ -161,11 +166,11 @@ def try_database_index(self, ctx, kb): return True return False - def kb_connect(self, dbnum=0): + def kb_connect(self, dbnum: Optional[int] = 0) -> RedisCtx: """ Connect to redis to the given database or to the default db 0 . Arguments: - dbnum (int, optional): The db number to connect to. + dbnum: The db number to connect to. Return a redis context on success. """ @@ -197,7 +202,7 @@ def kb_connect(self, dbnum=0): self.db_index = dbnum return ctx - def db_find(self, patt): + def db_find(self, patt: str) -> Optional[RedisCtx]: """ Search a pattern inside all kbs. When find it return it. """ for i in range(0, self.max_dbindex): @@ -207,7 +212,7 @@ def db_find(self, patt): return None - def kb_new(self): + def kb_new(self) -> Optional[RedisCtx]: """ Return a new kb context to an empty kb. """ ctx = self.db_find(self.DBINDEX_NAME) @@ -219,7 +224,7 @@ def kb_new(self): return None - def get_kb_context(self): + def get_kb_context(self) -> RedisCtx: """ Get redis context if it is already connected or do a connection. """ if self.rediscontext is not None: @@ -234,13 +239,13 @@ def get_kb_context(self): return self.rediscontext - def select_kb(self, ctx, kbindex, set_global=False): + def select_kb(self, ctx: RedisCtx, kbindex: str, set_global: bool = False): """ Use an existent redis connection and select a redis kb. If needed, set the ctx as global. Arguments: - ctx (redis obj): Redis context to use. - kbindex (str): The new kb to select - set_global (bool, optional): If should be the global context. + ctx: Redis context to use. + kbindex: The new kb to select + set_global: If should be the global context. """ if not ctx: raise RequiredArgument('select_kb', 'ctx') @@ -253,16 +258,20 @@ def select_kb(self, ctx, kbindex, set_global=False): self.db_index = str(kbindex) def get_list_item( - self, name, ctx=None, start=LIST_FIRST_POS, end=LIST_LAST_POS - ): + self, + name: str, + ctx: Optional[RedisCtx] = None, + start: Optional[int] = LIST_FIRST_POS, + end: Optional[int] = LIST_LAST_POS, + ) -> Optional[list]: """ Returns the specified elements from `start` to `end` of the list stored as `name`. Arguments: - name (str): key name of a list. - ctx (redis obj, optional): Redis context to use. - start (int, optional): first range element to get. - end (int, optional): last range element to get. + name: key name of a list. + ctx: Redis context to use. + start: first range element to get. + end: last range element to get. Return List specified elements in the key. """ @@ -273,12 +282,14 @@ def get_list_item( ctx = self.get_kb_context() return ctx.lrange(name, start, end) - def remove_list_item(self, key, value, ctx=None): + def remove_list_item( + self, key: str, value: str, ctx: Optional[RedisCtx] = None + ): """ Remove item from the key list. Arguments: - key (str): key name of a list. - value (str): Value to be removed from the key. - ctx (redis obj, optional): Redis context to use. + key: key name of a list. + value: Value to be removed from the key. + ctx: Redis context to use. """ if not key: raise RequiredArgument('remove_list_item', 'key') @@ -289,12 +300,17 @@ def remove_list_item(self, key, value, ctx=None): ctx = self.get_kb_context() ctx.lrem(key, count=LIST_ALL, value=value) - def get_single_item(self, name, ctx=None, index=LIST_FIRST_POS): + def get_single_item( + self, + name: str, + ctx: Optional[RedisCtx] = None, + index: Optional[int] = LIST_FIRST_POS, + ) -> Optional[str]: """ Get a single KB element. Arguments: - name (str): key name of a list. - ctx (redis obj, optional): Redis context to use. - index (int, optional): index of the element to be return. + name: key name of a list. + ctx: Redis context to use. + index: index of the element to be return. Return an element. """ if not name: @@ -304,12 +320,14 @@ def get_single_item(self, name, ctx=None, index=LIST_FIRST_POS): ctx = self.get_kb_context() return ctx.lindex(name, index) - def add_single_item(self, name, values, ctx=None): + def add_single_item( + self, name: str, values: List, ctx: Optional[RedisCtx] = None + ): """ Add a single KB element with one or more values. Arguments: - name (str): key name of a list. - value (list): Elements to add to the key. - ctx (redis obj, optional): Redis context to use. + name: key name of a list. + value: Elements to add to the key. + ctx: Redis context to use. """ if not name: raise RequiredArgument('add_list_item', 'name') @@ -320,12 +338,14 @@ def add_single_item(self, name, values, ctx=None): ctx = self.get_kb_context() ctx.rpush(name, *set(values)) - def set_single_item(self, name, value, ctx=None): + def set_single_item( + self, name: str, value: List, ctx: Optional[RedisCtx] = None + ): """ Set (replace) a single KB element. Arguments: - name (str): key name of a list. - value (list): New elements to add to the key. - ctx (redis obj, optional): Redis context to use. + name: key name of a list. + value: New elements to add to the key. + ctx: Redis context to use. """ if not name: raise RequiredArgument('set_single_item', 'name') @@ -339,11 +359,11 @@ def set_single_item(self, name, value, ctx=None): pipe.rpush(name, *set(value)) pipe.execute() - def get_pattern(self, pattern, ctx=None): + def get_pattern(self, pattern: str, ctx: Optional[RedisCtx] = None) -> List: """ Get all items stored under a given pattern. Arguments: - pattern (str): key pattern to match. - ctx (redis obj, optional): Redis context to use. + pattern: key pattern to match. + ctx: Redis context to use. Return a list with the elements under the matched key. """ if not pattern: @@ -363,13 +383,18 @@ def get_pattern(self, pattern, ctx=None): ) return elem_list - def get_elem_pattern_by_index(self, pattern, index=1, ctx=None): + def get_elem_pattern_by_index( + self, + pattern: str, + index: Optional[int] = 1, + ctx: Optional[RedisCtx] = None, + ) -> List: """ Get all items with index 'index', stored under a given pattern. Arguments: - pattern (str): key pattern to match. - index (int, optional): Index of the element to get from the list. - ctx (redis obj, optional): Redis context to use. + pattern: key pattern to match. + index: Index of the element to get from the list. + ctx: Redis context to use. Return a list with the elements under the matched key and given index. """ if not pattern: @@ -384,38 +409,40 @@ def get_elem_pattern_by_index(self, pattern, index=1, ctx=None): elem_list.append([item, ctx.lindex(item, index)]) return elem_list - def release_db(self, kbindex=0): + def release_db(self, kbindex: Optional[int] = 0): """ Connect to redis and select the db by index. Flush db and delete the index from dbindex_name list. Arguments: - kbindex (int, optional): KB index to flush and release. + kbindex: KB index to flush and release. """ ctx = self.kb_connect(kbindex) ctx.flushdb() ctx = self.kb_connect() ctx.hdel(self.DBINDEX_NAME, kbindex) - def get_result(self, ctx=None): + def get_result(self, ctx: Optional[RedisCtx] = None) -> Optional[List]: """ Get and remove the oldest result from the list. Arguments: - ctx (redis obj, optional): Redis context to use. + ctx: Redis context to use. Return a list with scan results """ if not ctx: ctx = self.get_kb_context() return ctx.rpop("internal/results") - def get_status(self, ctx=None): + def get_status(self, ctx: Optional[RedisCtx] = None) -> Optional[str]: """ Get and remove the oldest host scan status from the list. Arguments: - ctx (redis obj, optional): Redis context to use. + ctx: Redis context to use. Return a string which represents the host scan status. """ if not ctx: ctx = self.get_kb_context() return ctx.rpop("internal/status") - def get_host_scan_scan_start_time(self, ctx=None): + def get_host_scan_scan_start_time( + self, ctx: Optional[RedisCtx] = None + ) -> Optional[str]: """ Get the timestamp of the scan start from redis. Arguments: ctx (redis obj, optional): Redis context to use. @@ -425,20 +452,22 @@ def get_host_scan_scan_start_time(self, ctx=None): ctx = self.get_kb_context() return ctx.rpop("internal/start_time") - def get_host_scan_scan_end_time(self, ctx=None): + def get_host_scan_scan_end_time( + self, ctx: Optional[RedisCtx] = None + ) -> Optional[str]: """ Get the timestamp of the scan end from redis. Arguments: - ctx (redis obj, optional): Redis context to use. + ctx: Redis context to use. Return a string with the timestamp of scan end . """ if not ctx: ctx = self.get_kb_context() return ctx.rpop("internal/end_time") - def get_host_ip(self, ctx=None): + def get_host_ip(self, ctx: Optional[RedisCtx] = None) -> Optional[str]: """ Get the ip of host_kb. Arguments: - ctx (redis obj, optional): Redis context to use. + ctx: Redis context to use. Return a string with the ip of the host being scanned. """ if not ctx: diff --git a/ospd_openvas/nvticache.py b/ospd_openvas/nvticache.py index 5a7853d9..d9eef13f 100644 --- a/ospd_openvas/nvticache.py +++ b/ospd_openvas/nvticache.py @@ -23,12 +23,11 @@ import subprocess import sys -from typing import List, Dict, NewType -from redis import Redis +from typing import List, Dict, Optional from pkg_resources import parse_version -from ospd_openvas.db import NVT_META_FIELDS +from ospd_openvas.db import NVT_META_FIELDS, RedisCtx from ospd_openvas.errors import OspdOpenvasError @@ -39,8 +38,6 @@ SUPPORTED_NVTICACHE_VERSIONS = ('20.4',) -RedisCtx = NewType('RedisCtx', Redis) - class NVTICache(object): @@ -176,7 +173,7 @@ def _parse_metadata_tags(tags_str: str, oid: str) -> Dict: return tags_dict - def get_nvt_metadata(self, oid: str) -> Dict: + def get_nvt_metadata(self, oid: str) -> Optional[Dict]: """ Get a full NVT. Returns an XML tree with the NVT metadata. Arguments: oid: OID of VT from which to get the metadata. @@ -221,7 +218,7 @@ def get_nvt_metadata(self, oid: str) -> Dict: return custom - def get_nvt_refs(self, oid: str) -> Dict: + def get_nvt_refs(self, oid: str) -> Optional[Dict]: """ Get a full NVT. Arguments: oid: OID of VT from which to get the VT references. @@ -247,7 +244,7 @@ def get_nvt_refs(self, oid: str) -> Dict: return refs - def get_nvt_prefs(self, ctx: RedisCtx, oid: str) -> List: + def get_nvt_prefs(self, ctx: RedisCtx, oid: str) -> Optional[List]: """ Get NVT preferences. Arguments: ctx: Redis context to be used.