Skip to content

Commit

Permalink
[JVM-frontend] Add unit test for jvm frontend
Browse files Browse the repository at this point in the history
Signed-off-by: Arthur Chan <arthur.chan@adalogics.com>
  • Loading branch information
arthurscchan committed Jan 10, 2025
1 parent 0c3b682 commit 184d59b
Show file tree
Hide file tree
Showing 4 changed files with 265 additions and 20 deletions.
162 changes: 142 additions & 20 deletions src/fuzz_introspector/frontends/frontend_jvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@
"remainingBytes": "int"
}

# Maps the syntax-tree node type of a Java literal to the name of the Java
# type it evaluates to. The invocation helpers check `node.type in
# LITERAL_MAP` so that a literal used as a call target or argument
# contributes its type name directly instead of going through variable
# lookup.
LITERAL_MAP = {
"decimal_integer_literal": "int",
"hex_integer_literal": "int",
"octal_integer_literal": "int",
"binary_integer_literal": "int",
"decimal_floating_point_literal": "float",
"hex_floating_point_literal": "float",
"true": "boolean",
"false": "boolean",
"character_literal": "char",
"string_literal": "String",
"null_literal": "null"
}

class SourceCodeFile():
"""Class for holding file-specific information."""
Expand Down Expand Up @@ -114,7 +127,7 @@ def _set_package_declaration(self):
for _, nodes in res.items():
for node in nodes:
for package in node.children:
if package.type == 'scoped_identifier':
if package.type in ['scoped_identifier', 'identifier']:
self.package = package.text.decode()

def _set_class_interface_declaration(self):
Expand Down Expand Up @@ -338,7 +351,7 @@ def _process_declaration(self):
self.var_map[arg_name] = arg_type

# Process return type
elif child.type == 'type_identifier' or child.type.endswith(
elif child.type.endswith('type_identifier') or child.type.endswith(
'_type'):
self.return_type = child.text.decode()

Expand All @@ -352,7 +365,7 @@ def _process_declaration(self):
# Process exceptions
elif child.type == 'throws':
for exception in child.children:
if exception.type == 'type_identifier':
if exception.type.endswith('type_identifier'):
self.exceptions.append(exception.text.decode())

def _process_statements(self):
Expand Down Expand Up @@ -450,8 +463,13 @@ def _process_invoke_object(
"""Internal helper for processing the object from a invocation."""
callsites: list[tuple[str, int, int]] = []
return_value = ''

# Handle literal value
if stmt.type in LITERAL_MAP:
return_value = LITERAL_MAP[stmt.type]

# Determine the type of the object
if stmt.child_count == 0:
elif stmt.child_count == 0:
# Class call
if stmt.type == 'this':
return_value = self.class_interface.name
Expand All @@ -468,7 +486,7 @@ def _process_invoke_object(
stmt.text.decode(), '')
if not return_value and self.parent_source:
return_value = self.parent_source.imports.get(
stmt.text.decode(), self.class_interface.name)
stmt.text.decode(), '')
else:
# Field access
if stmt.type == 'field_access':
Expand Down Expand Up @@ -533,8 +551,35 @@ def _process_invoke_args(
for argument in stmt.children:
return_value = self.class_interface.name

# Handling literal value
if argument.type in LITERAL_MAP:
return_values.append(LITERAL_MAP[argument.type])

# Binary expression
elif argument.type == 'binary_expression':
found = False
other_type_node = []

# Try to locate literal values
for child in argument.children:
if child.type in LITERAL_MAP:
return_values.append(LITERAL_MAP[child.type])
found = True
else:
other_type_node.append(child)

# Only store the type if no value has been found yet
for node in other_type_node:
return_value, invoke = self._process_invoke(
node, classes)

if return_value and not found:
found = True
return_values.append(return_value)
callsites.extend(invoke)

# Variables
if argument.type == 'identifier':
elif argument.type == 'identifier':
return_value = self.var_map.get(argument.text.decode(), '')
if not return_value:
return_value = self.class_interface.class_fields.get(
Expand Down Expand Up @@ -629,12 +674,19 @@ def _process_invoke(
elif cls_type.type == 'super':
object_type = self.class_interface.super_class

elif cls_type.type == 'type_identifier' or cls_type.type.endswith(
elif cls_type.type.endswith('type_identifier') or cls_type.type.endswith(
'_type'):
object_type = cls_type.text.decode().split('<')[0]

object_type = self.parent_source.get_full_qualified_name(
object_type)

for cls in classes.values():
packaged_type = cls.add_package_to_class_name(object_type)
if packaged_type:
object_type = packaged_type
break

target_name = f'[{object_type}].<init>({",".join(argument_types)})'
callsites.append(
(target_name, expr.byte_range[1], expr.start_point.row + 1))
Expand All @@ -652,10 +704,27 @@ def _process_invoke(
# Process this method invocation
target_name = ''
if object_type and name:
for cls in classes.values():
packaged_type = cls.add_package_to_class_name(object_type)
if packaged_type:
object_type = packaged_type
break

target_name = f'[{object_type}].{name.text.decode()}({",".join(argument_types)})'
callsites.append(
(target_name, expr.byte_range[1], expr.start_point.row + 1))

# Calling to library outside of project
# Preserve the full method call
elif name:
if objects:
target_name = f'{objects.text.decode()}.{name.text.decode()}({",".join(argument_types)})'
else:
target_name = f'{name.text.decode()}({",".join(argument_types)})'

callsites.append(
(target_name, expr.byte_range[1], expr.start_point.row + 1))

# Determine return value from method invocation
if object_type == 'com.code_intelligence.jazzer.api.FuzzedDataProvider':
return_type = FUZZING_METHOD_RETURN_TYPE_MAP.get(
Expand All @@ -678,32 +747,41 @@ def _process_invoke(
def _process_callsites(
self, stmt: Node,
classes: dict[str,
'JavaClassInterface']) -> list[tuple[str, int, int]]:
'JavaClassInterface']) -> tuple[str, list[tuple[str, int, int]]]:
"""Process and store the callsites of the method."""
type = ''
callsites = []

if stmt.type == 'method_invocation':
_, invoke_callsites = self._process_invoke(stmt, classes)
type, invoke_callsites = self._process_invoke(stmt, classes)
callsites.extend(invoke_callsites)
elif stmt.type == 'object_creation_expression':
_, invoke_callsites = self._process_invoke(stmt, classes, True)
type, invoke_callsites = self._process_invoke(stmt, classes, True)
callsites.extend(invoke_callsites)
elif stmt.type == 'explicit_constructor_invocation':
_, invoke_callsites = self._process_invoke(stmt, classes, True)
type, invoke_callsites = self._process_invoke(stmt, classes, True)
callsites.extend(invoke_callsites)
elif stmt.type == 'assignment_expression':
left = stmt.child_by_field_name('left')
right = stmt.child_by_field_name('right')

var_name = left.text.decode().split(' ')[-1]
type, invoke_callsites = self._process_callsites(right, classes)
self.var_map[var_name] = type
callsites.extend(invoke_callsites)
else:
for child in stmt.children:
callsites.extend(self._process_callsites(child, classes))
callsites.extend(self._process_callsites(child, classes)[1])

return callsites
return type, callsites

def extract_callsites(self, classes: dict[str, 'JavaClassInterface']):
"""Extract callsites."""

if not self.base_callsites:
callsites = []
for stmt in self.stmts:
callsites.extend(self._process_callsites(stmt, classes))
callsites.extend(self._process_callsites(stmt, classes)[1])
callsites = sorted(set(callsites), key=lambda x: x[1])
self.base_callsites = [(x[0], x[2]) for x in callsites]

Expand Down Expand Up @@ -748,6 +826,17 @@ def __init__(self,
# Process inner classes
self._process_inner_classes(inner_class_nodes)

def add_package_to_class_name(self, name: str) -> Optional[str]:
    """Resolve a class name to the full name of this class or an inner class.

    The class matches when its own name equals its package joined with the
    last dotted component of ``name`` and its name also ends with ``name``.
    The second condition rejects a ``name`` that carries a different
    package prefix.

    Args:
        name: Simple or (partially) qualified class name to look up.

    Returns:
        The ``name`` attribute of the first matching class, searching this
        class first and then its inner classes recursively, or ``None`` if
        nothing matches.
    """
    simple_name = name.rsplit('.', 1)[-1]
    if (self.name == f'{self.package}.{simple_name}'
            and self.name.endswith(name)):
        return self.name

    # Keep searching when an inner class does not match: returning the
    # first inner class's result unconditionally would hide matches in
    # every later inner class.
    for inner_class in self.inner_classes:
        packaged_name = inner_class.add_package_to_class_name(name)
        if packaged_name:
            return packaged_name

    return None

def post_process_full_qualified_name(self):
"""Post process the full qualified name for types."""
# Refine class fields
Expand Down Expand Up @@ -778,7 +867,7 @@ def _process_node(self) -> list[Node]:
# Process super class
if child.type == 'superclass':
for cls in child.children:
if cls.type == 'type_identifier':
if cls.type.endswith('type_identifier'):
self.super_class = cls.text.decode()

# Process super interfaces
Expand All @@ -787,7 +876,7 @@ def _process_node(self) -> list[Node]:
if interfaces.type == 'type_list':
type_set = set()
for interface in interfaces.children:
if interface.type == 'type_identifier':
if interface.type.endswith('type_identifier'):
type_set.add(interface.text.decode())
self.super_interfaces = list(type_set)

Expand Down Expand Up @@ -1083,9 +1172,6 @@ def extract_calltree(self,
if not method and source_code:
method = source_code.get_entry_method_name(True)

if not method or not source_code:
return ''

line_to_print = ' ' * depth
line_to_print += method
line_to_print += ' '
Expand Down Expand Up @@ -1118,11 +1204,47 @@ def extract_calltree(self,

return line_to_print

def get_reachable_methods(self,
                          source_file: str,
                          source_code: Optional[SourceCodeFile] = None,
                          method: Optional[str] = None,
                          visited_methods: Optional[set[str]] = None) -> set[str]:
    """Collect the names of all methods reachable from a starting method.

    The starting point is ``method`` if given, otherwise the entry method
    of ``source_code``. Callsites are followed recursively. A callsite whose
    defining source or method node cannot be found is still recorded as
    reached but is not expanded further.

    Args:
        source_file: Path of the source file being analysed. Currently not
            read by this method; kept for interface compatibility.
        source_code: Source file object to start from. When omitted it is
            looked up from ``method``.
        method: Name of the method to start from. When omitted the entry
            method of ``source_code`` is used.
        visited_methods: Set of already-reached method names. It is updated
            in place and returned; a new set is created when ``None``.

    Returns:
        The set of reached method names, including the starting method.
        An empty (or unchanged) set is returned if no starting method can
        be determined.
    """
    # Compare against None so a caller-supplied empty set is reused.
    if visited_methods is None:
        visited_methods = set()

    if not source_code and method:
        source_code = self.find_source_with_method(method)

    if not method and source_code:
        method = source_code.get_entry_method_name(True)

    # Without a method name there is nothing to record; never add None
    # to a set that is declared to hold strings.
    if not method:
        return visited_methods

    visited_methods.add(method)

    method_node = source_code.get_method_node(method) if source_code else None
    if not method_node:
        # Method outside the analysed sources: reached, but a leaf.
        return visited_methods

    # NOTE(review): reads base_callsites directly, so this presumably
    # relies on callsites having been extracted beforehand - confirm.
    for cs, _ in method_node.base_callsites:
        if cs in visited_methods:
            continue

        visited_methods = self.get_reachable_methods(
            source_code.source_file,
            method=cs,
            visited_methods=visited_methods)

    return visited_methods


def capture_source_files_in_tree(directory_tree: str) -> list[str]:
"""Captures source code files in a given directory."""
exclude_directories = [
'target', 'test', 'node_modules', 'aflplusplus', 'honggfuzz',
'target', 'node_modules', 'aflplusplus', 'honggfuzz',
'inspector', 'libfuzzer'
]
language_extensions = ['.java']
Expand Down
28 changes: 28 additions & 0 deletions src/test/data/source-code/jvm/test-project-1/Fuzzer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package simple;

import com.code_intelligence.jazzer.api.FuzzedDataProvider;

/**
 * Minimal fixture class with two constructors and two instance methods.
 * Every member only prints a message to standard output.
 */
class SimpleClass {
/** No-argument constructor; not invoked by the Fuzzer class in this file. */
public SimpleClass() {
System.out.println("Default Constructor Called");
}

/**
 * Constructor taking a string; prints it upper-cased.
 * This is the constructor invoked by Fuzzer.fuzzerTestOneInput in this file.
 */
public SimpleClass(String param) {
System.out.println("Constructor with parameter called: " + param.toUpperCase());
}

/** Prints a fixed message; called by Fuzzer.fuzzerTestOneInput in this file. */
public void simpleMethod() {
System.out.println("Simple Method Called");
}

/** Prints a fixed message; nothing in this file calls it. */
public void unreachableMethod() {
System.out.println("Unreachable Method in SimpleClass");
}
}

/** Fuzzing harness fixture exposing the static fuzzerTestOneInput entry method. */
public class Fuzzer {
/**
 * Builds a SimpleClass from a string consumed from the fuzz data
 * (consumeString(10)) and then calls simpleMethod() on it.
 *
 * @param data provider of fuzzer-generated input
 */
public static void fuzzerTestOneInput(FuzzedDataProvider data) {
SimpleClass sc = new SimpleClass(data.consumeString(10));
sc.simpleMethod();
}
}
55 changes: 55 additions & 0 deletions src/test/data/source-code/jvm/test-project-2/Fuzzer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package polymorphism;

import com.code_intelligence.jazzer.api.FuzzedDataProvider;

/** Interface implemented by Dog and Cat in this file; the harness calls sound() through it. */
interface Animal {
/** Emits the animal's sound; the implementations in this file print a fixed string. */
void sound();
}

/** Animal implementation whose members only print messages to standard output. */
class Dog implements Animal {
/** No-argument constructor; prints a message with Math.random(). Not invoked in this file. */
public Dog() {
System.out.println("Dog Constructor Called: " + Math.random());
}

/** Constructor taking a name; prints it lower-cased. Invoked by Fuzzer.fuzzerTestOneInput. */
public Dog(String name) {
System.out.println("Dog Constructor with name: " + name.toLowerCase());
}

/** Prints "Bark". */
public void sound() {
System.out.println("Bark");
}

/** Prints a fixed message; nothing in this file calls it. */
public void unreachableDogMethod() {
System.out.println("Unreachable Method in Dog");
}
}

/** Animal implementation whose members only print messages to standard output. */
class Cat implements Animal {
/** No-argument constructor; prints a message with Math.random(). Not invoked in this file. */
public Cat() {
System.out.println("Cat Constructor Called: " + Math.random());
}

/** Constructor taking a name; prints it upper-cased. Invoked by Fuzzer.fuzzerTestOneInput. */
public Cat(String name) {
System.out.println("Cat Constructor with name: " + name.toUpperCase());
}

/** Prints "Meow". */
public void sound() {
System.out.println("Meow");
}

/** Prints a fixed message; nothing in this file calls it. */
public void unreachableCatMethod() {
System.out.println("Unreachable Method in Cat");
}
}

/** Fuzzing harness fixture exposing the static fuzzerTestOneInput entry method. */
public class Fuzzer {
/**
 * Picks Dog or Cat depending on whether a consumed string equals "dog",
 * constructs it with a second consumed string, and calls sound() through
 * the Animal interface type.
 *
 * @param data provider of fuzzer-generated input
 */
public static void fuzzerTestOneInput(FuzzedDataProvider data) {
// Declared with the interface type, so the sound() call below is not tied to one concrete class.
Animal animal;
if ("dog".equals(data.consumeString(10))) {
animal = new Dog(data.consumeString(10));
} else {
animal = new Cat(data.consumeString(10));
}
animal.sound();
}
}
Loading

0 comments on commit 184d59b

Please sign in to comment.