Skip to content

Commit

Permalink
Add seed + sample.py guide
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinvaidya17 committed May 23, 2022
1 parent c49a6fd commit b7c30a2
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def create_task_annotations(task: str, data_path: str, annotation_path: str) ->
Raises:
ValueError: When task is not classification, detection or segmentation.
"""
annotation_path = os.path.join(data_path, task)
annotation_path = os.path.join(annotation_path, task)
os.makedirs(annotation_path, exist_ok=True)

for split in ["train", "val", "test"]:
Expand Down
12 changes: 11 additions & 1 deletion external/anomaly/ote_anomalib/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Optional
from anomalib.utils.callbacks import MinMaxNormalizationCallback
from ote_anomalib import AnomalyInferenceTask
from ote_anomalib.callbacks import ProgressCallback
Expand All @@ -23,7 +24,7 @@
from ote_sdk.entities.model import ModelEntity
from ote_sdk.entities.train_parameters import TrainParameters
from ote_sdk.usecases.tasks.interfaces.training_interface import ITrainingTask
from pytorch_lightning import Trainer
from pytorch_lightning import Trainer, seed_everything

logger = get_logger(__name__)

Expand All @@ -36,17 +37,26 @@ def train(
dataset: DatasetEntity,
output_model: ModelEntity,
train_parameters: TrainParameters,
seed: Optional[int] = 0,
) -> None:
"""Train the anomaly classification model.
Args:
dataset (DatasetEntity): Input dataset.
output_model (ModelEntity): Output model to save the model weights.
train_parameters (TrainParameters): Training parameters
seed: (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
deterministic flag to True.
"""
logger.info("Training the model.")

config = self.get_config()

if seed is not None and seed > 0:
logger.info(f"Setting seed to {seed}")
seed_everything(seed, workers=True)
config.trainer.deterministic = True

logger.info("Training Configs '%s'", config)

datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)
Expand Down
3 changes: 2 additions & 1 deletion external/anomaly/tests/test_ote_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ def _run_ote_training(self, data_collector):
self.copy_hyperparams = deepcopy(self.task.task_environment.get_hyper_parameters())

try:
self.task.train(self.dataset, self.output_model, TrainParameters)
# fix seed so that result is repeatable
self.task.train(self.dataset, self.output_model, TrainParameters, seed=42)
except Exception as ex:
raise RuntimeError("Training failed") from ex

Expand Down
17 changes: 17 additions & 0 deletions external/anomaly/tools/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
OpenVINO Training Extension interacts with the anomaly detection library ([Anomalib](https://github.com/openvinotoolkit/anomalib)) by providing interfaces in the `external/anomaly` directory of this repository. The `sample.py` file contained in this folder serves as an end-to-end example of how these interfaces are used. To begin using this script, first ensure that the `ote_cli`, `ote_sdk` and `external/anomaly` dependencies are installed.

To get started, we provide a handy script in `ote_anomalib/data/create_mvtec_ad_json_annotations.py` to help generate annotation JSON files for the MVTec dataset. Assuming that you have placed the MVTec dataset in a directory in your home folder (`~/datasets/MVTec`), you can run the following command to generate the annotations.

```bash
python create_mvtec_ad_json_annotations.py --data_path ~/datasets/MVTec --annotation_path ~/training_extensions/data/MVtec/
```

This will generate three folders in `~/training_extensions/data/MVtec/`, one each for the classification, segmentation and detection tasks.

Then, to run sample.py you can use the following command.

```bash
python tools/sample.py --dataset_path ~/datasets/MVTec --category bottle --train-ann-files ../../data/MVtec/bottle/segmentation/train.json --val-ann-files ../../data/MVtec/bottle/segmentation/val.json --test-ann-files ../../data/MVtec/bottle/segmentation/test.json --model_template_path ./configs/anomaly_segmentation/padim/template.yaml
```

Optionally, you can also optimize the model with `nncf` or `pot` by using the `--optimization` flag.
22 changes: 15 additions & 7 deletions external/anomaly/tools/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import shutil
from argparse import Namespace
from typing import Any, Dict, Type, Union
from typing import Any, Dict, Optional, Type, Union

from ote_anomalib import AnomalyNNCFTask, OpenVINOAnomalyTask
from ote_anomalib.data.dataset import (
Expand Down Expand Up @@ -61,13 +61,18 @@ def __init__(
val_subset: Dict[str, str],
test_subset: Dict[str, str],
model_template_path: str,
seed: Optional[int] = 0,
) -> None:
"""Initialize OteAnomalyTask.
Args:
dataset_path (str): Path to the MVTec dataset.
seed (int): Seed to split the dataset into train/val/test splits.
train_subset (Dict[str, str]): Dictionary containing path to train annotation file and path to dataset.
val_subset (Dict[str, str]): Dictionary containing path to validation annotation file and path to dataset.
test_subset (Dict[str, str]): Dictionary containing path to test annotation file and path to dataset.
model_template_path (str): Path to model template.
seed (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
deterministic flag to True.
Example:
>>> import os
Expand All @@ -78,9 +83,12 @@ def __init__(
>>> model_template_path = "./configs/anomaly_classification/padim/template.yaml"
>>> dataset_path = "./datasets/MVTec"
>>> seed = 0
>>> task = OteAnomalyTask(
... dataset_path=dataset_path, seed=seed, model_template_path=model_template_path
... dataset_path=dataset_path,
... train_subset={"ann_file": train.json, "data_root": dataset_path},
... val_subset={"ann_file": val.json, "data_root": dataset_path},
... test_subset={"ann_file": test.json, "data_root": dataset_path},
... model_template_path=model_template_path
... )
>>> task.train()
Expand Down Expand Up @@ -110,6 +118,7 @@ def __init__(
self.openvino_task: OpenVINOAnomalyTask
self.nncf_task: AnomalyNNCFTask
self.results = {"category": dataset_path}
self.seed = seed

def get_dataclass(
self,
Expand Down Expand Up @@ -176,9 +185,7 @@ def train(self) -> ModelEntity:
configuration=self.task_environment.get_model_configuration(),
)
self.torch_task.train(
dataset=self.dataset,
output_model=output_model,
train_parameters=TrainParameters(),
dataset=self.dataset, output_model=output_model, train_parameters=TrainParameters(), seed=self.seed
)

logger.info("Inferring the base torch model on the validation set.")
Expand Down Expand Up @@ -364,6 +371,7 @@ def main() -> None:
val_subset=val_subset,
test_subset=test_subset,
model_template_path=args.model_template_path,
seed=args.seed,
)

task.train()
Expand Down

0 comments on commit b7c30a2

Please sign in to comment.