Skip to content

Commit

Permalink
Merge pull request #909 from openvinotoolkit/ad/check_nncf_graph_2
Browse files Browse the repository at this point in the history
[OTE_SDK_TESTS] Rework test of nncf graph
  • Loading branch information
Ilya-Krylov authored Feb 14, 2022
2 parents 234bb80 + f23895b commit 461d501
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 75 deletions.
197 changes: 125 additions & 72 deletions ote_sdk/ote_sdk/test_suite/training_tests_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,38 +74,31 @@ def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDi
raise NotImplementedError("The main action method is not implemented")


def create_environment_and_task(params, labels_schema, model_template):
environment = TaskEnvironment(
model=None,
hyper_parameters=params,
label_schema=labels_schema,
model_template=model_template,
)
logger.info("Create base Task")
task_impl_path = model_template.entrypoints.base
task_cls = get_impl_class(task_impl_path)
task = task_cls(task_environment=environment)
return environment, task


class OTETestTrainingAction(BaseOTETestAction):
_name = "training"

def __init__(
self,
dataset,
labels_schema,
template_path,
num_training_iters,
batch_size,
reference_dir,
self, dataset, labels_schema, template_path, num_training_iters, batch_size
):
self.dataset = dataset
self.labels_schema = labels_schema
self.template_path = template_path
self.num_training_iters = num_training_iters
self.batch_size = batch_size
self.reference_dir = reference_dir

@staticmethod
def _create_environment_and_task(params, labels_schema, model_template):
environment = TaskEnvironment(
model=None,
hyper_parameters=params,
label_schema=labels_schema,
model_template=model_template,
)
logger.info("Create base Task")
task_impl_path = model_template.entrypoints.base
task_cls = get_impl_class(task_impl_path)
task = task_cls(task_environment=environment)
return environment, task

def _get_training_performance_as_score_name_value(self):
training_performance = getattr(self.output_model, "performance", None)
Expand Down Expand Up @@ -154,7 +147,7 @@ def _run_ote_training(self, data_collector):
)

logger.debug("Setup environment")
self.environment, self.task = self._create_environment_and_task(
self.environment, self.task = create_environment_and_task(
params, self.labels_schema, self.model_template
)

Expand Down Expand Up @@ -184,7 +177,6 @@ def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDi
"dataset": self.dataset,
"environment": self.environment,
"output_model": self.output_model,
"reference_dir": self.reference_dir,
}
return results

Expand Down Expand Up @@ -455,55 +447,12 @@ def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDi
return results


# TODO: think about move to special file
def check_nncf_model_graph(reference_dir, nncf_task):
import networkx as nx

# pylint:disable=protected-access
if reference_dir is None:
logger.warning("reference_dir is None")
return True
path_to_dot = os.path.join(reference_dir, "nncf", f"{nncf_task._nncf_preset}.dot")
if not os.path.exists(path_to_dot):
logger.warning(f"Reference file does not exist: {path_to_dot}")
return True
logger.info(f"Reference graph: {path_to_dot}")
load_graph = nx.drawing.nx_pydot.read_dot(path_to_dot)

graph = nncf_task._model.get_graph()
nx_graph = graph.get_graph_for_structure_analysis()

for _, node in nx_graph.nodes(data=True):
if "scope" in node:
node.pop("scope")

for k, attrs in nx_graph.nodes.items():
attrs = {k: str(v) for k, v in attrs.items()}
load_attrs = {k: str(v).strip('"') for k, v in load_graph.nodes[k].items()}
if "scope" in load_attrs:
load_attrs.pop("scope")
if attrs != load_attrs:
logger.info("ATTR: {} : {} != {}".format(k, attrs, load_attrs))
return False

return (
load_graph.nodes.keys() == nx_graph.nodes.keys()
and nx.DiGraph(load_graph).edges == nx_graph.edges
)


class OTETestNNCFAction(BaseOTETestAction):
_name = "nncf"
_depends_stages_names = ["training"]

def _run_ote_nncf(
self,
data_collector,
model_template,
dataset,
trained_model,
environment,
reference_dir,
self, data_collector, model_template, dataset, trained_model, environment
):
logger.debug("Get predictions on the validation set for exported model")
self.environment_for_nncf = deepcopy(environment)
Expand Down Expand Up @@ -542,9 +491,6 @@ def _run_ote_nncf(
assert (
self.nncf_model.model_format == ModelFormat.BASE_FRAMEWORK
), "Wrong model format"
assert check_nncf_model_graph(
reference_dir, self.nncf_task
), "Compressed model differs from the reference"

logger.info("NNCF optimization is finished")

Expand All @@ -556,7 +502,6 @@ def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDi
"dataset": results_prev_stages["training"]["dataset"],
"trained_model": results_prev_stages["training"]["output_model"],
"environment": results_prev_stages["training"]["environment"],
"reference_dir": results_prev_stages["training"]["reference_dir"],
}

self._run_ote_nncf(data_collector, **kwargs)
Expand All @@ -568,6 +513,113 @@ def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDi
return results


# TODO: think about move to special file
def check_nncf_model_graph(model, path_to_dot):
import networkx as nx

logger.info(f"Reference graph: {path_to_dot}")
load_graph = nx.drawing.nx_pydot.read_dot(path_to_dot)

graph = model.get_graph()
nx_graph = graph.get_graph_for_structure_analysis()

for _, node in nx_graph.nodes(data=True):
if "scope" in node:
node.pop("scope")

for k, attrs in nx_graph.nodes.items():
attrs = {k: str(v) for k, v in attrs.items()}
load_attrs = {k: str(v).strip('"') for k, v in load_graph.nodes[k].items()}
if "scope" in load_attrs:
load_attrs.pop("scope")
if attrs != load_attrs:
logger.info("ATTR: {} : {} != {}".format(k, attrs, load_attrs))
return False

return (
load_graph.nodes.keys() == nx_graph.nodes.keys()
and nx.DiGraph(load_graph).edges == nx_graph.edges
)


class OTETestNNCFGraphAction(BaseOTETestAction):
_name = "nncf_graph"

def __init__(
self,
dataset,
labels_schema,
template_path,
reference_dir,
fn_get_compressed_model,
):
self.dataset = dataset
self.labels_schema = labels_schema
self.template_path = template_path
self.reference_dir = reference_dir
self.fn_get_compressed_model = fn_get_compressed_model

def _run_ote_nncf_graph(self, data_collector):
# pylint:disable=protected-access
logger.debug("Load model template")
model_template = parse_model_template(self.template_path)
nncf_task_class_impl_path = model_template.entrypoints.nncf

if not nncf_task_class_impl_path:
pytest.skip("NNCF is not enabled for this template")

if not is_nncf_enabled():
pytest.skip("NNCF is not installed")

if not os.path.exists(self.reference_dir):
pytest.skip("Reference directory does not exist")

params = ote_sdk_configuration_helper_create(
model_template.hyper_parameters.data
)
environment, task = create_environment_and_task(
params, self.labels_schema, model_template
)
output_model = ModelEntity(
self.dataset,
environment.get_model_configuration(),
)
# Save model without training to create nncf_task
task.save_model(output_model)

logger.info("Create NNCF Task")
environment_for_nncf = deepcopy(environment)

logger.info("Creating NNCF task and structures")
nncf_model = ModelEntity(
self.dataset,
environment_for_nncf.get_model_configuration(),
)
nncf_model.set_data("weights.pth", output_model.get_data("weights.pth"))

environment_for_nncf.model = nncf_model

nncf_task_cls = get_impl_class(nncf_task_class_impl_path)
nncf_task = nncf_task_cls(task_environment=environment_for_nncf)

path_to_ref_dot = os.path.join(
self.reference_dir, "nncf", f"{nncf_task._nncf_preset}.dot"
)
if not os.path.exists(path_to_ref_dot):
pytest.skip("Reference file does not exist: {}".format(path_to_ref_dot))

compressed_model = self.fn_get_compressed_model(nncf_task)

assert check_nncf_model_graph(
compressed_model, path_to_ref_dot
), "Compressed model differs from the reference"

def __call__(self, data_collector: DataCollector, results_prev_stages: OrderedDict):
self._check_result_prev_stages(results_prev_stages, self.depends_stages_names)
self._run_ote_nncf_graph(data_collector)
return {}


class OTETestNNCFEvaluationAction(BaseOTETestAction):
_name = "nncf_evaluation"
_with_validation = True
Expand Down Expand Up @@ -700,4 +752,5 @@ def get_default_test_action_classes() -> List[Type[BaseOTETestAction]]:
OTETestNNCFEvaluationAction,
OTETestNNCFExportAction,
OTETestNNCFExportEvaluationAction,
OTETestNNCFGraphAction,
]

0 comments on commit 461d501

Please sign in to comment.