Skip to content

Commit

Permalink
pylint: Fix pylint error for frontend_go
Browse files Browse the repository at this point in the history
Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>
  • Loading branch information
arthurscchan committed Jan 22, 2025
1 parent a8c12d8 commit 869851d
Showing 1 changed file with 87 additions and 63 deletions.
150 changes: 87 additions & 63 deletions src/fuzz_introspector/frontends/frontend_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
class GoSourceCodeFile(SourceCodeFile):
"""Class for holding file-specific information."""

def language_specific_process(self):
def language_specific_process(self) -> None:
"""Perform some language specific processes in subclasses."""
self.imports: list[str] = []
# List of function definitions in the source file.
Expand Down Expand Up @@ -88,12 +88,14 @@ def _set_imports(self):
for _, imports in import_query_res.items():
for imp in imports:
for import_spec in imp.children:
if import_spec.type == 'import_spec_list':
for path in import_spec.children:
if path.type == 'import_spec':
path = path.text.decode().replace('"', '')
# Only store the package name, not full path
import_set.add(path.rsplit('/', 1)[-1])
if import_spec.type != 'import_spec_list':
continue

for path in import_spec.children:
if path.type == 'import_spec':
path = path.text.decode().replace('"', '')
# Only store the package name, not full path
import_set.add(path.rsplit('/', 1)[-1])

self.imports = list(import_set)

Expand Down Expand Up @@ -177,6 +179,7 @@ def dump_module_logic(self,
harness_source: str = '',
dump_output: bool = True):
"""Dumps the data for the module in full."""
# pylint: disable=unused-argument
logger.info('Dumping project-wide logic.')
report: dict[str, Any] = {'report': 'name'}
report['sources'] = []
Expand Down Expand Up @@ -258,6 +261,7 @@ def extract_calltree(self,
line_number: int = -1,
other_props: Optional[dict[str, Any]] = None) -> str:
"""Extracts calltree string of a calltree so that FI core can use it."""
# pylint: disable=unused-argument
if not visited_functions:
visited_functions = set()

Expand Down Expand Up @@ -292,25 +296,23 @@ def extract_calltree(self,
return line_to_print

visited_functions.add(function)
for cs, line_number in func.base_callsites:
for cs, line in func.base_callsites:
line_to_print += self.extract_calltree(
source_code.source_file,
function=cs,
visited_functions=visited_functions,
depth=depth + 1,
line_number=line_number)
line_number=line)
return line_to_print

def get_source_codes_with_harnesses(self) -> list[GoSourceCodeFile]:
return super().get_source_codes_with_harnesses()

def get_reachable_functions(
self,
source_file: str = '',
source_code: Optional[GoSourceCodeFile] = None,
function: Optional[str] = None,
visited_functions: Optional[set[str]] = None) -> set[str]:
"""Get a list of reachable functions for a provided function name."""
# pylint: disable=unused-argument
if not visited_functions:
visited_functions = set()

Expand All @@ -333,7 +335,7 @@ def get_reachable_functions(
return visited_functions

visited_functions.add(function)
for cs, line_number in func.base_callsites:
for cs, _ in func.base_callsites:
visited_functions = self.get_reachable_functions(
source_code.source_file,
function=cs,
Expand Down Expand Up @@ -480,19 +482,18 @@ def _process_properties(self):
param_name = param_tmp.text.decode()

# Param type
if param.child_by_field_name('type'):
type_str = param.child_by_field_name(
'type').text.decode()
param_tmp = param
while param_tmp.child_by_field_name(
'declarator') is not None:
if param_tmp.type == 'pointer_declarator':
type_str += '*'
param_tmp = param_tmp.child_by_field_name(
'declarator')
param_type = type_str

if param_name:
if not param.child_by_field_name('type'):
continue
type_str = param.child_by_field_name('type').text.decode()
param_tmp = param
while param_tmp.child_by_field_name(
'declarator') is not None:
if param_tmp.type == 'pointer_declarator':
type_str += '*'
param_tmp = param_tmp.child_by_field_name('declarator')
param_type = type_str

if param_name and param_type:
if param_name != self.receiver_name:
param_names.append(param_name)
param_types.append(param_type)
Expand Down Expand Up @@ -575,17 +576,21 @@ def _process_call_expr_child(
self, call_child: Node,
all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]:
"""Internal helper to process call expr."""
# pylint: disable=unused-argument
target_name = None

# Simple call
if call_child.type == 'identifier':
target_name = call_child.text.decode()
if call_child.text:
target_name = call_child.text.decode()

# Package/method call
if call_child.type == 'selector_expression':
target_name = call_child.text.decode()
if call_child.text:
target_name = call_child.text.decode()

# Variable call
split_call = target_name.split('.')
split_call = target_name.split('.') if target_name else []
if len(split_call) > 1:
# For indexing selector
if '[' in split_call[-2] and ']' in split_call[-2]:
Expand All @@ -601,10 +606,11 @@ def _process_call_expr_child(
target_name = f'{var_name}.{split_call[-1]}'

elif split_call[0] not in self.parent_source.imports:
target_name = target_name.split('.')[-1]
target_name = (target_name.split('.')[-1]
if target_name else None)

# Chain call
split_call = target_name.rsplit(').', 1)
split_call = target_name.rsplit(').', 1) if target_name else []
if len(split_call) > 1:
target_name = split_call[1]

Expand All @@ -613,47 +619,52 @@ def _process_call_expr_child(
def _detect_variable_type(
self, node: Node,
all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]:
"""Internal recursive helper to determine the return type of the expression."""
"""Internal recursive helper to determine the return type of the
expression."""

for child in node.children:
# Literals
if child.type in LITERAL_TYPE_MAP:
return LITERAL_TYPE_MAP[child.type]

# Identifier
elif child.type == 'identifier':
if child.text.decode() in self.var_map:
if child.type == 'identifier':
if child.text and child.text.decode() in self.var_map:
return self.var_map[child.text.decode()]

# Composite Literal
elif child.type == 'composite_literal':
composite_type = child.child_by_field_name('type')
if composite_type:
if composite_type and composite_type.text:
return composite_type.text.decode()

# Call expression
elif child.type == 'call_expression':
call = child.child_by_field_name('function')
args = child.child_by_field_name('arguments')

if not call or not args:
continue

target_name = self._process_call_expr_child(
call, all_funcs_meths)

if target_name in all_funcs_meths:
return all_funcs_meths[target_name].return_type

elif target_name == 'new':
if target_name == 'new':
for arg in args.children:
if arg.type.endswith('identifier'):
if arg.type.endswith('identifier') and arg.text:
return arg.text.decode()

elif target_name == 'make':
for arg in args.children:
type_node = arg.child_by_field_name('value')
if type_node:
if type_node and type_node.text:
return type_node.text.decode()

type_node = arg.child_by_field_name('element')
if type_node:
if type_node and type_node.text:
return type_node.text.decode()

# Selector expression
Expand All @@ -666,11 +677,14 @@ def _detect_variable_type(
# Index expression / Slice expression
elif child.type in ['index_expression', 'slice_expression']:
op = child.child_by_field_name('operand')
if not op or not op.text:
continue

parent_type = self.var_map.get(op.text.decode())
if parent_type:
if '[' in parent_type and ']' in parent_type:
return parent_type.rsplit(']', 1)[-1]
elif parent_type == 'string':
if parent_type == 'string':
return 'uint8'

# Other expression that need to recursive deeper
Expand All @@ -691,11 +705,13 @@ def extract_local_variable_type(self,
for decl_node in exprs:
left = decl_node.child_by_field_name('left')
right = decl_node.child_by_field_name('right')
decl_name = ''

if not left or not right:
continue

for child in left.children:
if child.type == 'identifier':
if child.type == 'identifier' and child.text:
decl_name = child.text.decode()

decl_type = self._detect_variable_type(right, all_funcs_meths)
Expand All @@ -707,29 +723,33 @@ def extract_local_variable_type(self,
for _, exprs in query.captures(self.root).items():
for for_node in exprs:
for child in for_node.children:
if child.type == 'range_clause':
left = child.child_by_field_name('left')
right = child.child_by_field_name('right')
if not left or not right:
continue
if child.type != 'range_clause':
continue

left = child.child_by_field_name('left')
right = child.child_by_field_name('right')
if not left or not right:
continue

for left_child in left.children:
if left_child.type == 'identifier' and left_child.text:
decl_name = left_child.text.decode()

for left_child in left.children:
if left_child.type == 'identifier':
decl_name = left_child.text.decode()
if right.type == 'identifier':
if not right.text:
continue

if right.type == 'identifier':
decl_type = self.var_map.get(
right.text.decode(), '')
if '[' in decl_type and ']' in decl_type:
decl_type = decl_type.split(']', 1)[-1]
elif decl_type == 'string':
decl_type = 'uint8'
else:
decl_type = self._detect_variable_type(
right, all_funcs_meths)
decl_type = self.var_map.get(right.text.decode(), '')
if '[' in decl_type and ']' in decl_type:
decl_type = decl_type.split(']', 1)[-1]
elif decl_type == 'string':
decl_type = 'uint8'
else:
decl_type = self._detect_variable_type(
right, all_funcs_meths)

if decl_name and decl_type:
self.var_map[decl_name] = decl_type
if decl_name and decl_type:
self.var_map[decl_name] = decl_type

def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']):
"""Gets the callsites of the function."""
Expand All @@ -740,6 +760,9 @@ def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']):
for _, call_exprs in call_res.items():
for call_expr in call_exprs:
call = call_expr.child_by_field_name('function')
if not call:
continue

target_name = self._process_call_expr_child(
call, all_funcs_meths)
if target_name in ['new', 'make']:
Expand All @@ -757,13 +780,14 @@ def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']):
self.base_callsites = [(x[0], x[2]) for x in callsites]
# Process detailed callsites
for dst, src_line in self.base_callsites:
src_loc = self.parent_source.source_file + ':%d,1' % (src_line)
src_loc = f'{self.parent_source.source_file}:{src_line},1'
self.detailed_callsites.append({'Src': src_loc, 'Dst': dst})


def load_treesitter_trees(source_files: list[str],
is_log: bool = True) -> GoProject:
"""Creates treesitter trees for all files in a given list of source files."""
"""Creates treesitter trees for all files in a given list of source
files."""
results = []

for code_file in source_files:
Expand Down

0 comments on commit 869851d

Please sign in to comment.