From 869851d50aaf7b5012a591a486e371b8437c87ec Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Wed, 22 Jan 2025 16:48:46 +0000 Subject: [PATCH] pylint: Fix pylint error for frontend_go Signed-off-by: Arthur Chan --- .../frontends/frontend_go.py | 150 ++++++++++-------- 1 file changed, 87 insertions(+), 63 deletions(-) diff --git a/src/fuzz_introspector/frontends/frontend_go.py b/src/fuzz_introspector/frontends/frontend_go.py index eecd43aab..749b6de8e 100644 --- a/src/fuzz_introspector/frontends/frontend_go.py +++ b/src/fuzz_introspector/frontends/frontend_go.py @@ -41,7 +41,7 @@ class GoSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def language_specific_process(self): + def language_specific_process(self) -> None: """Perform some language specific processes in subclasses.""" self.imports: list[str] = [] # List of function definitions in the source file. @@ -88,12 +88,14 @@ def _set_imports(self): for _, imports in import_query_res.items(): for imp in imports: for import_spec in imp.children: - if import_spec.type == 'import_spec_list': - for path in import_spec.children: - if path.type == 'import_spec': - path = path.text.decode().replace('"', '') - # Only store the package name, not full path - import_set.add(path.rsplit('/', 1)[-1]) + if import_spec.type != 'import_spec_list': + continue + + for path in import_spec.children: + if path.type == 'import_spec': + path = path.text.decode().replace('"', '') + # Only store the package name, not full path + import_set.add(path.rsplit('/', 1)[-1]) self.imports = list(import_set) @@ -177,6 +179,7 @@ def dump_module_logic(self, harness_source: str = '', dump_output: bool = True): """Dumps the data for the module in full.""" + # pylint: disable=unused-argument logger.info('Dumping project-wide logic.') report: dict[str, Any] = {'report': 'name'} report['sources'] = [] @@ -258,6 +261,7 @@ def extract_calltree(self, line_number: int = -1, other_props: Optional[dict[str, Any]] = None) -> str: """Extracts calltree string of a calltree so that FI core can use it.""" + # pylint: disable=unused-argument if not visited_functions: visited_functions = set() @@ -292,18 +296,15 @@ def extract_calltree(self, return line_to_print visited_functions.add(function) - for cs, line_number in func.base_callsites: + for cs, line in func.base_callsites: line_to_print += self.extract_calltree( source_code.source_file, function=cs, visited_functions=visited_functions, depth=depth + 1, - line_number=line_number) + line_number=line) return line_to_print - def get_source_codes_with_harnesses(self) -> list[GoSourceCodeFile]: - return super().get_source_codes_with_harnesses() - def get_reachable_functions( self, source_file: str = '', @@ -311,6 +312,7 @@ def get_reachable_functions( function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" + # pylint: disable=unused-argument if not visited_functions: visited_functions = set() @@ -333,7 +335,7 @@ def get_reachable_functions( return visited_functions visited_functions.add(function) - for cs, line_number in func.base_callsites: + for cs, _ in func.base_callsites: visited_functions = self.get_reachable_functions( source_code.source_file, function=cs, @@ -480,19 +482,18 @@ def _process_properties(self): param_name = param_tmp.text.decode() # Param type - if param.child_by_field_name('type'): - type_str = param.child_by_field_name( - 'type').text.decode() - param_tmp = param - while param_tmp.child_by_field_name( - 'declarator') is not None: - if param_tmp.type == 'pointer_declarator': - type_str += '*' - param_tmp = param_tmp.child_by_field_name( - 'declarator') - param_type = type_str - - if param_name: + if not param.child_by_field_name('type'): + continue + type_str = param.child_by_field_name('type').text.decode() + param_tmp = param + while param_tmp.child_by_field_name( + 'declarator') is not None: + if param_tmp.type == 'pointer_declarator': + type_str += '*' + param_tmp = param_tmp.child_by_field_name('declarator') + param_type = type_str + + if param_name and param_type: if param_name != self.receiver_name: param_names.append(param_name) param_types.append(param_type) @@ -575,17 +576,21 @@ def _process_call_expr_child( self, call_child: Node, all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]: """Internal helper to process call expr.""" + # pylint: disable=unused-argument target_name = None # Simple call if call_child.type == 'identifier': - target_name = call_child.text.decode() + if call_child.text: + target_name = call_child.text.decode() # Package/method call if call_child.type == 'selector_expression': - target_name = call_child.text.decode() + if call_child.text: + target_name = call_child.text.decode() + # Variable call - split_call = target_name.split('.') + split_call = target_name.split('.') if target_name else [] if len(split_call) > 1: # For indexing selector if '[' in split_call[-2] and ']' in split_call[-2]: @@ -601,10 +606,11 @@ def _process_call_expr_child( target_name = f'{var_name}.{split_call[-1]}' elif split_call[0] not in self.parent_source.imports: - target_name = target_name.split('.')[-1] + target_name = (target_name.split('.')[-1] + if target_name else None) # Chain call - split_call = target_name.rsplit(').', 1) + split_call = target_name.rsplit(').', 1) if target_name else [] if len(split_call) > 1: target_name = split_call[1] @@ -613,7 +619,8 @@ def _process_call_expr_child( def _detect_variable_type( self, node: Node, all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]: - """Internal recursive helper to determine the return type of the expression.""" + """Internal recursive helper to determine the return type of the + expression.""" for child in node.children: # Literals @@ -621,39 +628,43 @@ def _detect_variable_type( return LITERAL_TYPE_MAP[child.type] # Identifier - elif child.type == 'identifier': - if child.text.decode() in self.var_map: + if child.type == 'identifier': + if child.text and child.text.decode() in self.var_map: return self.var_map[child.text.decode()] # Composite Literal elif child.type == 'composite_literal': composite_type = child.child_by_field_name('type') - if composite_type: + if composite_type and composite_type.text: return composite_type.text.decode() # Call expression elif child.type == 'call_expression': call = child.child_by_field_name('function') args = child.child_by_field_name('arguments') + + if not call or not args: + continue + target_name = self._process_call_expr_child( call, all_funcs_meths) if target_name in all_funcs_meths: return all_funcs_meths[target_name].return_type - elif target_name == 'new': + if target_name == 'new': for arg in args.children: - if arg.type.endswith('identifier'): + if arg.type.endswith('identifier') and arg.text: return arg.text.decode() elif target_name == 'make': for arg in args.children: type_node = arg.child_by_field_name('value') - if type_node: + if type_node and type_node.text: return type_node.text.decode() type_node = arg.child_by_field_name('element') - if type_node: + if type_node and type_node.text: return type_node.text.decode() # Selector expression @@ -666,11 +677,14 @@ def _detect_variable_type( # Index expression / Slice expression elif child.type in ['index_expression', 'slice_expression']: op = child.child_by_field_name('operand') + if not op or not op.text: + continue + parent_type = self.var_map.get(op.text.decode()) if parent_type: if '[' in parent_type and ']' in parent_type: return parent_type.rsplit(']', 1)[-1] - elif parent_type == 'string': + if parent_type == 'string': return 'uint8' # Other expression that need to recursive deeper @@ -691,11 +705,13 @@ def extract_local_variable_type(self, for decl_node in exprs: left = decl_node.child_by_field_name('left') right = decl_node.child_by_field_name('right') + decl_name = '' + if not left or not right: continue for child in left.children: - if child.type == 'identifier': + if child.type == 'identifier' and child.text: decl_name = child.text.decode() decl_type = self._detect_variable_type(right, all_funcs_meths) @@ -707,29 +723,33 @@ def extract_local_variable_type(self, for _, exprs in query.captures(self.root).items(): for for_node in exprs: for child in for_node.children: - if child.type == 'range_clause': - left = child.child_by_field_name('left') - right = child.child_by_field_name('right') - if not left or not right: - continue + if child.type != 'range_clause': + continue + + left = child.child_by_field_name('left') + right = child.child_by_field_name('right') + if not left or not right: + continue + + for left_child in left.children: + if left_child.type == 'identifier' and left_child.text: + decl_name = left_child.text.decode() - for left_child in left.children: - if left_child.type == 'identifier': - decl_name = left_child.text.decode() + if right.type == 'identifier': + if not right.text: + continue - if right.type == 'identifier': - decl_type = self.var_map.get( - right.text.decode(), '') - if '[' in decl_type and ']' in decl_type: - decl_type = decl_type.split(']', 1)[-1] - elif decl_type == 'string': - decl_type = 'uint8' - else: - decl_type = self._detect_variable_type( - right, all_funcs_meths) + decl_type = self.var_map.get(right.text.decode(), '') + if '[' in decl_type and ']' in decl_type: + decl_type = decl_type.split(']', 1)[-1] + elif decl_type == 'string': + decl_type = 'uint8' + else: + decl_type = self._detect_variable_type( + right, all_funcs_meths) - if decl_name and decl_type: - self.var_map[decl_name] = decl_type + if decl_name and decl_type: + self.var_map[decl_name] = decl_type def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']): """Gets the callsites of the function.""" @@ -740,6 +760,9 @@ def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']): for _, call_exprs in call_res.items(): for call_expr in call_exprs: call = call_expr.child_by_field_name('function') + if not call: + continue + target_name = self._process_call_expr_child( call, all_funcs_meths) if target_name in ['new', 'make']: @@ -757,13 +780,14 @@ def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']): self.base_callsites = [(x[0], x[2]) for x in callsites] # Process detailed callsites for dst, src_line in self.base_callsites: - src_loc = self.parent_source.source_file + ':%d,1' % (src_line) + src_loc = f'{self.parent_source.source_file}:{src_line},1' self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) def load_treesitter_trees(source_files: list[str], is_log: bool = True) -> GoProject: - """Creates treesitter trees for all files in a given list of source files.""" + """Creates treesitter trees for all files in a given list of source + files.""" results = [] for code_file in source_files: