diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index c2f3be9b03..5c56ae8d87 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -151,6 +151,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: task=config.dataset.task, clip_length_in_frames=config.dataset.clip_length_in_frames, frames_between_clips=config.dataset.frames_between_clips, + target_frame=config.dataset.target_frame, image_size=(config.dataset.image_size[0], config.dataset.image_size[1]), center_crop=center_crop, normalization=config.dataset.normalization, @@ -169,6 +170,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: task=config.dataset.task, clip_length_in_frames=config.dataset.clip_length_in_frames, frames_between_clips=config.dataset.frames_between_clips, + target_frame=config.dataset.target_frame, image_size=(config.dataset.image_size[0], config.dataset.image_size[1]), center_crop=center_crop, normalization=config.dataset.normalization, @@ -205,6 +207,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: task=config.dataset.task, clip_length_in_frames=config.dataset.clip_length_in_frames, frames_between_clips=config.dataset.frames_between_clips, + target_frame=config.dataset.target_frame, image_size=(config.dataset.image_size[0], config.dataset.image_size[1]), center_crop=center_crop, normalization=config.dataset.normalization, diff --git a/src/anomalib/data/avenue.py b/src/anomalib/data/avenue.py index 517323d5e4..ff1dae3c2f 100644 --- a/src/anomalib/data/avenue.py +++ b/src/anomalib/data/avenue.py @@ -28,6 +28,7 @@ from pandas import DataFrame from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame from anomalib.data.task_type import TaskType from anomalib.data.utils import ( DownloadInfo, @@ -140,6 +141,7 @@ class AvenueDataset(AnomalibVideoDataset): split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval """ def __init__( @@ -151,8 +153,9 @@ def __init__( split: Split, clip_length_in_frames: int = 1, frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, ) -> None: - super().__init__(task, transform, clip_length_in_frames, frames_between_clips) + super().__init__(task, transform, clip_length_in_frames, frames_between_clips, target_frame) self.root = root if isinstance(root, Path) else Path(root) self.gt_dir = gt_dir if isinstance(gt_dir, Path) else Path(gt_dir) @@ -172,6 +175,7 @@ class Avenue(AnomalibVideoDataModule): gt_dir (Path | str): Path to the ground truth files clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval task TaskType): Task type, 'classification', 'detection' or 'segmentation' image_size (int | tuple[int, int] | None, optional): Size of the input image. Defaults to None. 
@@ -198,6 +202,7 @@ def __init__( gt_dir: Path | str, clip_length_in_frames: int = 1, frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, task: TaskType = TaskType.SEGMENTATION, image_size: int | tuple[int, int] | None = None, center_crop: int | tuple[int, int] | None = None, @@ -241,6 +246,7 @@ def __init__( transform=transform_train, clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, + target_frame=target_frame, root=root, gt_dir=gt_dir, split=Split.TRAIN, @@ -251,6 +257,7 @@ def __init__( transform=transform_eval, clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, + target_frame=target_frame, root=root, gt_dir=gt_dir, split=Split.TEST, diff --git a/src/anomalib/data/shanghaitech.py b/src/anomalib/data/shanghaitech.py index da55c0ad9f..1244d2e267 100644 --- a/src/anomalib/data/shanghaitech.py +++ b/src/anomalib/data/shanghaitech.py @@ -28,6 +28,7 @@ from torch import Tensor from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame from anomalib.data.task_type import TaskType from anomalib.data.utils import ( DownloadInfo, @@ -187,6 +188,7 @@ class ShanghaiTechDataset(AnomalibVideoDataset): split (Split): Split of the dataset, usually Split.TRAIN or Split.TEST clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval """ def __init__( @@ -198,8 +200,9 @@ def __init__( split: Split, clip_length_in_frames: int = 1, frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, ): - super().__init__(task, transform, clip_length_in_frames, frames_between_clips) + super().__init__(task, transform, clip_length_in_frames, frames_between_clips, target_frame) self.root = root self.scene = scene @@ -219,6 +222,7 @@ class ShanghaiTech(AnomalibVideoDataModule): scene (int): Index of the dataset scene (category) in range [1, 13] clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval task TaskType): Task type, 'classification', 'detection' or 'segmentation' image_size (int | tuple[int, int] | None, optional): Size of the input image. Defaults to None. 
@@ -245,6 +249,7 @@ def __init__( scene: int, clip_length_in_frames: int = 1, frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, task: TaskType = TaskType.SEGMENTATION, image_size: int | tuple[int, int] | None = None, center_crop: int | tuple[int, int] | None = None, @@ -288,6 +293,7 @@ def __init__( transform=transform_train, clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, + target_frame=target_frame, root=root, scene=scene, split=Split.TRAIN, @@ -298,6 +304,7 @@ def __init__( transform=transform_eval, clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, + target_frame=target_frame, root=root, scene=scene, split=Split.TEST, diff --git a/src/anomalib/data/ucsd_ped.py b/src/anomalib/data/ucsd_ped.py index 42e22ca9cb..e3ce32099c 100644 --- a/src/anomalib/data/ucsd_ped.py +++ b/src/anomalib/data/ucsd_ped.py @@ -18,6 +18,7 @@ from torch import Tensor from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset +from anomalib.data.base.video import VideoTargetFrame from anomalib.data.task_type import TaskType from anomalib.data.utils import ( DownloadInfo, @@ -155,6 +156,7 @@ class UCSDpedDataset(AnomalibVideoDataset): split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval """ def __init__( @@ -166,8 +168,9 @@ def __init__( split: Split, clip_length_in_frames: int = 1, frames_between_clips: int = 1, + target_frame: VideoTargetFrame = VideoTargetFrame.LAST, ) -> None: - super().__init__(task, transform, clip_length_in_frames, frames_between_clips) + super().__init__(task, transform, clip_length_in_frames, frames_between_clips, target_frame) self.root_category = Path(root) / category self.split = split @@ -186,6 +189,7 @@ class UCSDped(AnomalibVideoDataModule): category (str): Sub-category of the dataset, e.g. 'bottle' clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. + target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval task (TaskType): Task type, 'classification', 'detection' or 'segmentation' image_size (int | tuple[int, int] | None, optional): Size of the input image. Defaults to None. 
@@ -215,6 +219,7 @@ def __init__(
         category: str,
         clip_length_in_frames: int = 1,
         frames_between_clips: int = 1,
+        target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
         task: TaskType = TaskType.SEGMENTATION,
         image_size: int | tuple[int, int] | None = None,
         center_crop: int | tuple[int, int] | None = None,
@@ -258,6 +263,7 @@ def __init__(
             transform=transform_train,
             clip_length_in_frames=clip_length_in_frames,
             frames_between_clips=frames_between_clips,
+            target_frame=target_frame,
             root=root,
             category=category,
             split=Split.TRAIN,
@@ -268,6 +274,7 @@ def __init__(
             transform=transform_eval,
             clip_length_in_frames=clip_length_in_frames,
             frames_between_clips=frames_between_clips,
+            target_frame=target_frame,
             root=root,
             category=category,
             split=Split.TEST,
diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py
index b881cf8cf4..f28d649c93 100644
--- a/src/anomalib/models/__init__.py
+++ b/src/anomalib/models/__init__.py
@@ -12,6 +12,7 @@
 from omegaconf import DictConfig, ListConfig
 from torch import load

+from anomalib.models.ai_vad import AiVad
 from anomalib.models.cfa import Cfa
 from anomalib.models.cflow import Cflow
 from anomalib.models.components import AnomalyModule
@@ -41,6 +42,7 @@
     "ReverseDistillation",
     "Rkde",
     "Stfpm",
+    "AiVad",
 ]

 logger = logging.getLogger(__name__)
@@ -92,6 +94,7 @@ def get_model(config: DictConfig | ListConfig) -> AnomalyModule:
         "reverse_distillation",
         "rkde",
         "stfpm",
+        "ai_vad",
     ]

     model: AnomalyModule
diff --git a/src/anomalib/models/ai_vad/__init__.py b/src/anomalib/models/ai_vad/__init__.py
new file mode 100644
index 0000000000..9e1833abb4
--- /dev/null
+++ b/src/anomalib/models/ai_vad/__init__.py
@@ -0,0 +1,13 @@
+"""Implementation of the AI-VAD Model.
+
+AI-VAD: Accurate and Interpretable Video Anomaly Detection
+
+Paper https://arxiv.org/pdf/2212.00789.pdf
+"""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .lightning_model import AiVad, AiVadLightning
+
+__all__ = ["AiVad", "AiVadLightning"]
diff --git a/src/anomalib/models/ai_vad/clip/LICENSE b/src/anomalib/models/ai_vad/clip/LICENSE
new file mode 100644
index 0000000000..c123b69334
--- /dev/null
+++ b/src/anomalib/models/ai_vad/clip/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 OpenAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
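For reference, a minimal sketch of how the pieces added in this change fit together (the "ai_vad" model name, its bundled config.yaml, and the new target_frame dataset option), assuming anomalib's existing config-driven entry points (get_configurable_parameters, get_datamodule, get_model); the call sequence mirrors tools/train.py and is illustrative rather than part of this patch:

# Illustrative usage sketch -- not part of this diff; mirrors the standard anomalib 0.x training flow.
from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model

# Resolves to the config.yaml bundled with the model (src/anomalib/models/ai_vad/config.yaml),
# which sets model.name to ai_vad and dataset.target_frame as introduced above.
config = get_configurable_parameters(model_name="ai_vad")

datamodule = get_datamodule(config)  # passes dataset.target_frame through to the video datamodule
model = get_model(config)            # "ai_vad" now resolves to the AiVad lightning module

trainer = Trainer(**config.trainer)
trainer.fit(model=model, datamodule=datamodule)
# Roughly equivalent CLI: python tools/train.py --model ai_vad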
diff --git a/src/anomalib/models/ai_vad/clip/__init__.py b/src/anomalib/models/ai_vad/clip/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/anomalib/models/ai_vad/clip/clip.py b/src/anomalib/models/ai_vad/clip/clip.py new file mode 100644 index 0000000000..e5065f0b32 --- /dev/null +++ b/src/anomalib/models/ai_vad/clip/clip.py @@ -0,0 +1,213 @@ +# mypy: ignore-errors +# ruff: noqa + +# Original Code +# https://github.com/openai/CLIP. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import hashlib +import os +import urllib +import warnings +from typing import List, Union + +import torch +from PIL import Image +from pkg_resources import packaging +from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor +from tqdm import tqdm + +from .model import build_model + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + + +if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): + warnings.warn("PyTorch version 1.7.1 or higher is recommended") + + +__all__ = ["available_models", "load"] + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024 + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, 
"rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose( + [ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load( + name: str, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, + download_root: str = None, +): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). + + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, "rb") as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) diff --git a/src/anomalib/models/ai_vad/clip/model.py b/src/anomalib/models/ai_vad/clip/model.py new file mode 100644 index 0000000000..0395d9c904 --- /dev/null +++ b/src/anomalib/models/ai_vad/clip/model.py @@ -0,0 +1,476 @@ +# mypy: ignore-errors +# ruff: noqa + +# Original Code +# https://github.com/openai/CLIP. +# SPDX-License-Identifier: MIT +# +# Modified +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + 
return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__( + self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width, + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim, + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in 
[self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width**-0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * 
grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width**2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() diff --git a/src/anomalib/models/ai_vad/config.yaml b/src/anomalib/models/ai_vad/config.yaml new file mode 100644 index 0000000000..cef4659319 --- /dev/null +++ b/src/anomalib/models/ai_vad/config.yaml @@ -0,0 +1,110 @@ +dataset: + name: ucsd #options: [mvtec, btech, folder] + format: ucsdped + path: ./datasets/ucsd + category: UCSDped2 + task: detection + clip_length_in_frames: 2 + frames_between_clips: 10 + target_frame: "last" + image_size: [240, 360] # AI-VAD region- and feature extractor apply custom resizing + normalization: none # AI-VAD region- and feature extractor apply custom normalization + train_batch_size: 8 + eval_batch_size: 1 + num_workers: 8 + transform_config: + train: null + eval: null + test_split_mode: from_dir + test_split_ratio: 0.2 + val_split_mode: same_as_test # options: [same_as_test, from_test] + val_split_ratio: 0.5 + tiling: + apply: false + tile_size: null + stride: null + remove_border_count: 0 + use_random_tiling: False + random_tile_count: 16 + +model: + name: ai_vad + box_score_thresh: 0.7 + n_velocity_bins: 1 + use_velocity_features: True + use_pose_features: True + use_deep_features: True + n_components_velocity: 2 + n_neighbors_pose: 1 + n_neighbors_deep: 1 + # generic params + normalization_method: min_max # options: [null, min_max, cdf] + +metrics: + image: + - AUROC + threshold: + method: adaptive #options: [adaptive, manual] + manual_image: null + manual_pixel: null + +visualization: + show_images: False # show images on the screen + save_images: True # save images to the file system + log_images: False # log images to the available loggers (if any) + image_save_path: null # path to which images will be saved + mode: full # options: ["full", "simple"] + +project: + seed: 42 + path: ./results + +logging: + logger: [] # options: [comet, tensorboard, wandb, csv] or combinations. + log_graph: false # Logs the model graph to respective logger. + +optimization: + export_mode: null # options: onnx, openvino + +# PL Trainer Args. Don't add extra parameter here. 
+trainer: + enable_checkpointing: true + default_root_dir: null + gradient_clip_val: 0 + gradient_clip_algorithm: norm + num_nodes: 1 + devices: 1 + enable_progress_bar: true + overfit_batches: 0.0 + track_grad_norm: -1 + check_val_every_n_epoch: 1 # Don't validate before extracting features. + fast_dev_run: false + accumulate_grad_batches: 1 + max_epochs: 1 + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + limit_train_batches: 1.0 + limit_val_batches: 1.0 + limit_test_batches: 1.0 + limit_predict_batches: 1.0 + val_check_interval: 1.0 # Don't validate before extracting features. + log_every_n_steps: 50 + accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto"> + strategy: null + sync_batchnorm: false + precision: 32 + enable_model_summary: true + num_sanity_val_steps: 0 + profiler: null + benchmark: false + deterministic: false + reload_dataloaders_every_n_epochs: 0 + auto_lr_find: false + replace_sampler_ddp: true + detect_anomaly: false + auto_scale_batch_size: false + plugins: null + move_metrics_to_cpu: false + multiple_trainloader_mode: max_size_cycle diff --git a/src/anomalib/models/ai_vad/density.py b/src/anomalib/models/ai_vad/density.py new file mode 100644 index 0000000000..0ed8cd705c --- /dev/null +++ b/src/anomalib/models/ai_vad/density.py @@ -0,0 +1,296 @@ +"""Density estimation module for AI-VAD model implementation.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +import torch +from sklearn.mixture import GaussianMixture +from torch import Tensor, nn + +from anomalib.models.ai_vad.features import FeatureType +from anomalib.utils.metrics.min_max import MinMax + + +class BaseDensityEstimator(nn.Module, ABC): + """Base density estimator.""" + + @abstractmethod + def update(self, features: dict[FeatureType, Tensor] | Tensor, group: Any = None): + """Update the density model with a new set of features.""" + raise NotImplementedError + + @abstractmethod + def predict(self, features: dict[FeatureType, Tensor] | Tensor): + """Predict the density of a set of features.""" + raise NotImplementedError + + @abstractmethod + def fit(self): + """Compose model using collected features.""" + raise NotImplementedError + + def forward(self, features: dict[FeatureType, Tensor] | Tensor): + """Update or predict depending on training status.""" + if self.training: + self.update(features) + return None + return self.predict(features) + + +class CombinedDensityEstimator(BaseDensityEstimator): + """Density estimator for AI-VAD. + + Combines density estimators for the different feature types included in the model. + + Args: + use_velocity_features (bool): Flag indicating if velocity features should be used. + use_pose_features (bool): Flag indicating if pose features should be used. + use_deep_features (bool): Flag indicating if deep features should be used. + n_components_velocity (int): Number of components used by GMM density estimation for velocity features. + n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features. + n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features. 
+ """ + + def __init__( + self, + use_pose_features: bool = True, + use_deep_features: bool = True, + use_velocity_features: bool = False, + n_neighbors_pose: int = 1, + n_neighbors_deep: int = 1, + n_components_velocity: int = 5, + ) -> None: + super().__init__() + + self.use_pose_features = use_pose_features + self.use_deep_features = use_deep_features + self.use_velocity_features = use_velocity_features + + if self.use_velocity_features: + self.velocity_estimator = GMMEstimator(n_components=n_components_velocity) + if self.use_deep_features: + self.appearance_estimator = GroupedKNNEstimator(n_neighbors_deep) + if self.use_pose_features: + self.pose_estimator = GroupedKNNEstimator(n_neighbors=n_neighbors_pose) + assert any((use_pose_features, use_deep_features, use_velocity_features)) + + def update(self, features: dict[FeatureType, Tensor], group: Any = None) -> None: + """Update the density estimators for the different feature types. + + Args: + features (dict[FeatureType, Tensor]): Dictionary containing extracted features for a single frame. + group (str): Identifier of the video from which the frame was sampled. Used for grouped density estimation. + """ + if self.use_velocity_features: + self.velocity_estimator.update(features[FeatureType.VELOCITY]) + if self.use_deep_features: + self.appearance_estimator.update(features[FeatureType.DEEP], group=group) + if self.use_pose_features: + self.pose_estimator.update(features[FeatureType.POSE], group=group) + + def fit(self): + """Fit the density estimation models on the collected features.""" + if self.use_velocity_features: + self.velocity_estimator.fit() + if self.use_deep_features: + self.appearance_estimator.fit() + if self.use_pose_features: + self.pose_estimator.fit() + + def predict(self, features: dict[FeatureType, Tensor]) -> tuple[Tensor, Tensor]: + """Predict the region- and image-level anomaly scores for an image based on a set of features. + + Args: + features (dict[Tensor]): Dictionary containing extracted features for a single frame. + + Returns: + Tensor: Region-level anomaly scores for all regions withing the frame. + Tensor: Frame-level anomaly score for the frame. + """ + n_regions = list(features.values())[0].shape[0] + device = list(features.values())[0].device + region_scores = torch.zeros(n_regions).to(device) + image_score = 0 + if self.use_velocity_features: + velocity_scores = self.velocity_estimator.predict(features[FeatureType.VELOCITY]) + region_scores += velocity_scores + image_score += velocity_scores.max() + if self.use_deep_features: + deep_scores = self.appearance_estimator.predict(features[FeatureType.DEEP]) + region_scores += deep_scores + image_score += deep_scores.max() + if self.use_pose_features: + pose_scores = self.pose_estimator.predict(features[FeatureType.POSE]) + region_scores += pose_scores + image_score += pose_scores.max() + return region_scores, image_score + + +class GroupedKNNEstimator(BaseDensityEstimator): + """Grouped KNN density estimator. + + Keeps track of the group (e.g. video id) from which the features were sampled for normalization purposes. + + Args: + n_neighbors (int): Number of neighbors used in KNN search. + """ + + def __init__(self, n_neighbors: int) -> None: + super().__init__() + + self.n_neighbors = n_neighbors + self.memory_bank: dict[Any, list[Tensor] | Tensor] = {} + self.normalization_statistics = MinMax() + + def update(self, features: Tensor, group: Any = None) -> None: + """Update the internal feature bank while keeping track of the group. 
+ + Args: + features (Tensor): Feature vectors extracted from a video frame. + group (Any): Identifier of the group (video) from which the frame was sampled. + """ + group = group or "default" + + if group in self.memory_bank: + self.memory_bank[group].append(features) + else: + self.memory_bank[group] = [features] + + def fit(self) -> None: + """Fit the KNN model by stacking the feature vectors and computing the normalization statistics.""" + self.memory_bank = {key: torch.vstack(value) for key, value in self.memory_bank.items()} + self._compute_normalization_statistics() + + def predict(self, features: Tensor, group: Any = None, n_neighbors: int = 1, normalize: bool = True) -> Tensor: + """Predict the (normalized) density for a set of features. + + Args: + features (Tensor): Input features that will be compared to the density model. + group (Any, optional): Group (video id) from which the features originate. If passed, all features of the + same group in the memory bank will be excluded from the density estimation. + n_neighbors (int): Number of neighbors used in the KNN search. + normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. + Returns: + Tensor: Mean (normalized) distances of input feature vectors to k nearest neighbors in feature bank. + """ + n_neighbors = n_neighbors or self.n_neighbors + + if group: + mem_bank = self.memory_bank.copy() + mem_bank.pop(group) + else: + mem_bank = self.memory_bank + + mem_bank_tensor = torch.vstack(list(mem_bank.values())) + + distances = self._nearest_neighbors(mem_bank_tensor, features, n_neighbors=n_neighbors) + + if normalize: + distances = self._normalize(distances) + + return distances.mean(axis=1) + + @staticmethod + def _nearest_neighbors(feature_bank: Tensor, features: Tensor, n_neighbors: int = 1): + """Perform the KNN search. + + Args: + feature_bank (Tensor): Feature bank used for KNN search. + features (Ternsor): Input features. + n_neighbors (int): Number of neighbors used in KNN search. + Returns: + Tensor: Distances between the input features and their K nearest neighbors in the feature bank. + """ + distances = torch.cdist(features, feature_bank, p=2.0) # euclidean norm + if n_neighbors == 1: + # when n_neighbors is 1, speed up computation by using min instead of topk + distances, _ = distances.min(1) + return distances.unsqueeze(1) + distances, _ = distances.topk(k=n_neighbors, largest=False, dim=1) + return distances + + def _compute_normalization_statistics(self): + """Compute min-max normalization statistics while taking the group into account.""" + for group, features in self.memory_bank.items(): + distances = self.predict(features, group, normalize=False) + self.normalization_statistics.update(distances) + + self.normalization_statistics.compute() + + def _normalize(self, distances: Tensor): + """Normalize distance predictions. + + Args: + distances (Tensor): Distance tensor produced by KNN search. + Returns: + Tensor: Normalized distances. + """ + return (distances - self.normalization_statistics.min) / ( + self.normalization_statistics.max - self.normalization_statistics.min + ) + + +class GMMEstimator(BaseDensityEstimator): + """Density estimation based on Gaussian Mixture Model. + + Args: + n_components (int): Number of components used in the GMM. 
+ """ + + def __init__(self, n_components: int = 2) -> None: + super().__init__() + + # TODO: replace with custom pytorch implementation of GMM (CVS-109432) + self.gmm = GaussianMixture(n_components=n_components, random_state=0) + self.memory_bank: list[Tensor] | Tensor = [] + + self.normalization_statistics = MinMax() + + def update(self, features: Tensor, group: Any = None): + """Update the feature bank.""" + del group + self.memory_bank.append(features) + + def fit(self): + """Fit the GMM and compute normalization statistics.""" + self.memory_bank = torch.vstack(self.memory_bank) + self.gmm.fit(self.memory_bank.cpu()) + self._compute_normalization_statistics() + + def predict(self, features: Tensor, normalize: bool = True) -> Tensor: + """Predict the density of a set of feature vectors. + + Args: + features (Tensor): Input feature vectors. + normalize (bool): Flag indicating if the density should be normalized to min-max stats of the feature bank. + Returns: + Tensor: Density scores of the input feature vectors. + """ + density = -self.gmm.score_samples(features.cpu()) + density = Tensor(density).to(self.normalization_statistics.device) + if normalize: + density = self._normalize(density) + return density + + def _compute_normalization_statistics(self): + """Compute min-max normalization statistics over the feature bank.""" + training_scores = self.predict(self.memory_bank, normalize=False) + self.normalization_statistics.update(training_scores) + self.normalization_statistics.compute() + + def _normalize(self, density: Tensor): + """Normalize distance predictions. + + Args: + distances (Tensor): Distance tensor produced by KNN search. + Returns: + Tensor: Normalized distances. + """ + return (density - self.normalization_statistics.min) / ( + self.normalization_statistics.max - self.normalization_statistics.min + ) diff --git a/src/anomalib/models/ai_vad/features.py b/src/anomalib/models/ai_vad/features.py new file mode 100644 index 0000000000..0ee0e36c04 --- /dev/null +++ b/src/anomalib/models/ai_vad/features.py @@ -0,0 +1,244 @@ +"""Feature extraction module for AI-VAD model implementation""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from enum import Enum + +import torch +from torch import Tensor, nn +from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights, keypointrcnn_resnet50_fpn +from torchvision.models.detection.roi_heads import keypointrcnn_inference +from torchvision.ops import roi_align +from torchvision.transforms import Normalize + +from anomalib.models.ai_vad.clip import clip + + +class FeatureType(str, Enum): + """Names of the different feature streams used in AI-VAD.""" + + POSE = "pose" + VELOCITY = "velocity" + DEEP = "deep" + + +class FeatureExtractor(nn.Module): + """Feature extractor for AI-VAD. + + Args: + n_velocity_bins (int): Number of discrete bins used for velocity histogram features. + use_velocity_features (bool): Flag indicating if velocity features should be used. + use_pose_features (bool): Flag indicating if pose features should be used. + use_deep_features (bool): Flag indicating if deep features should be used. + """ + + def __init__( + self, + n_velocity_bins: int = 8, + use_velocity_features: bool = True, + use_pose_features: bool = True, + use_deep_features: bool = True, + ) -> None: + super().__init__() + assert ( + use_velocity_features or use_pose_features or use_deep_features + ), "At least one feature stream must be enabled." 
+ + self.use_velocity_features = use_velocity_features + self.use_pose_features = use_pose_features + self.use_deep_features = use_deep_features + + self.deep_extractor = DeepExtractor() + self.velocity_extractor = VelocityExtractor(n_bins=n_velocity_bins) + self.pose_extractor = PoseExtractor() + + def forward( + self, + rgb_batch: Tensor, + flow_batch: Tensor, + regions: list[dict], + ) -> list[dict]: + """Forward pass through the feature extractor. + + Extract any combination of velocity, pose and deep features depending on configuration. + + Args: + rgb_batch (Tensor): Batch of RGB images of shape (N, 3, H, W) + flow_batch (Tensor): Batch of optical flow images of shape (N, 2, H, W) + regions (list[dict]): Region information per image in batch. + Returns: + list[dict]: Feature dictionary per image in batch. + """ + batch_size = rgb_batch.shape[0] + + # convert from list of [N, 4] tensors to single [N, 5] tensor where each row is [index-in-batch, x1, y1, x2, y2] + boxes_list = [batch_item["boxes"] for batch_item in regions] + indices = torch.repeat_interleave( + torch.arange(len(regions)), Tensor([boxes.shape[0] for boxes in boxes_list]).int() + ) + boxes = torch.cat([indices.unsqueeze(1).to(rgb_batch.device), torch.cat(boxes_list)], dim=1) + + # Extract features + feature_dict = {} + if self.use_velocity_features: + velocity_features = self.velocity_extractor(flow_batch, boxes) + feature_dict[FeatureType.VELOCITY] = [velocity_features[indices == i] for i in range(batch_size)] + if self.use_pose_features: + pose_features = self.pose_extractor(rgb_batch, boxes_list) + feature_dict[FeatureType.POSE] = pose_features + if self.use_deep_features: + deep_features = self.deep_extractor(rgb_batch, boxes, batch_size) + feature_dict[FeatureType.DEEP] = [deep_features[indices == i] for i in range(batch_size)] + + # dict of lists to list of dicts + feature_collection = [dict(zip(feature_dict, item)) for item in zip(*feature_dict.values())] + + return feature_collection + + +class DeepExtractor(nn.Module): + """Deep feature extractor. + + Extracts the deep (appearance) features from the input regions. + """ + + def __init__(self) -> None: + super().__init__() + + self.encoder, _ = clip.load("ViT-B/16") + self.transform = Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + + def forward(self, batch: Tensor, boxes: Tensor, batch_size: int) -> Tensor: + """Extract deep features using CLIP encoder. + + Args: + batch (Tensor): Batch of RGB input images of shape (N, 3, H, W) + boxes (Tensor): Bounding box coordinates of shaspe (M, 5). First column indicates batch index of the bbox. + batch_size (int): Number of images in the batch. + + Returns: + Tensor: Deep feature tensor of shape (M, 512) + """ + rgb_regions = roi_align(batch, boxes, output_size=[224, 224]) + + features = [] + batched_regions = torch.split(rgb_regions, batch_size) + with torch.no_grad(): + features = torch.vstack([self.encoder.encode_image(self.transform(batch)) for batch in batched_regions]) + + return features + + +class VelocityExtractor(nn.Module): + """Velocity feature extractor. + + Extracts histograms of optical flow magnitude and direction. + + Args: + n_bins (int): Number of direction bins used for the feature histograms. + """ + + def __init__(self, n_bins: int = 8) -> None: + super().__init__() + + self.n_bins = n_bins + + def forward(self, flows: Tensor, boxes: Tensor) -> Tensor: + """Extract velocioty features by filling a histogram. 
+ + Args: + flows (Tensor): Batch of optical flow images of shape (N, 2, H, W) + boxes (Tensor): Bounding box coordinates of shaspe (M, 5). First column indicates batch index of the bbox. + + Returns: + Tensor: Velocity feature tensor of shape (M, n_bins) + """ + flow_regions = roi_align(flows, boxes, output_size=[224, 224]) + + # cartesian to polar + mag_batch = torch.linalg.norm(flow_regions, axis=1, ord=2) + theta_batch = torch.atan2(flow_regions[:, 0, ...], flow_regions[:, 1, ...]) + + # compute velocity histogram + velocity_histograms = [] + for mag, theta in zip(mag_batch, theta_batch): + histogram_mag = torch.histogram( + input=theta.cpu(), bins=self.n_bins, range=(-torch.pi, torch.pi), weight=mag.cpu() + ).hist + histogram_counts = torch.histogram(input=theta.cpu(), bins=self.n_bins, range=(-torch.pi, torch.pi)).hist + final_histogram = torch.zeros_like(histogram_mag) + mask = histogram_counts != 0 + final_histogram[mask] = histogram_mag[mask] / histogram_counts[mask] + velocity_histograms.append(final_histogram) + + return torch.stack(velocity_histograms).to(flows.device) + + +class PoseExtractor(nn.Module): + """Pose feature extractor. + + Extracts pose features based on estimated body landmark keypoints. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT + model = keypointrcnn_resnet50_fpn(weights=weights) + self.model = model + self.transform = model.transform + self.backbone = model.backbone + self.roi_heads = model.roi_heads + + @staticmethod + def _post_process(keypoint_detections: list[dict]) -> list[Tensor]: + """Convert keypoint predictions to 1D feature vectors. + + Post-processing consists of flattening and normalizing to bbox coordinates. + + Args: + keypoint_detections (list[dict]): Outputs of the keypoint extractor + + Returns: + list[Tensor]: List of pose feature tensors for each image + """ + poses = [] + for detection in keypoint_detections: + boxes = detection["boxes"].unsqueeze(1) + keypoints = detection["keypoints"] + normalized_keypoints = (keypoints[..., :2] - boxes[..., :2]) / (boxes[..., 2:] - boxes[..., :2]) + poses.append(normalized_keypoints.reshape(normalized_keypoints.shape[0], -1)) + return poses + + def forward(self, batch: Tensor, boxes: Tensor) -> list[Tensor]: + """Extract pose features using a human keypoint estimation model. + + Args: + batch (Tensor): Batch of RGB input images of shape (N, 3, H, W) + boxes (Tensor): Bounding box coordinates of shaspe (M, 5). First column indicates batch index of the bbox. + + Returns: + list[Tensor]: list of pose feature tensors for each image. 
+ """ + images, _ = self.transform(batch) + features = self.backbone(images.tensors) + + image_sizes = [b.shape[-2:] for b in batch] + scales = [Tensor(new) / Tensor([orig[0], orig[1]]) for orig, new in zip(image_sizes, images.image_sizes)] + + boxes = [box * scale.repeat(2).to(box.device) for box, scale in zip(boxes, scales)] + + keypoint_features = self.roi_heads.keypoint_roi_pool(features, boxes, images.image_sizes) + keypoint_features = self.roi_heads.keypoint_head(keypoint_features) + keypoint_logits = self.roi_heads.keypoint_predictor(keypoint_features) + keypoints_probs, _ = keypointrcnn_inference(keypoint_logits, boxes) + + keypoint_detections = self.transform.postprocess( + [{"keypoints": keypoints, "boxes": box} for keypoints, box in zip(keypoints_probs, boxes)], + images.image_sizes, + image_sizes, + ) + return self._post_process(keypoint_detections) diff --git a/src/anomalib/models/ai_vad/flow.py b/src/anomalib/models/ai_vad/flow.py new file mode 100644 index 0000000000..525b9c015e --- /dev/null +++ b/src/anomalib/models/ai_vad/flow.py @@ -0,0 +1,61 @@ +"""Optical Flow extraction module for AI-VAD implementation.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import torch +import torchvision.transforms.functional as F +from torch import Tensor, nn +from torchvision.models.optical_flow import Raft_Large_Weights, raft_large + + +class FlowExtractor(nn.Module): + """Optical Flow extractor. + + Computes the pixel displacement between 2 consecutive frames from a video clip. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + weights = Raft_Large_Weights.DEFAULT + self.model = raft_large(weights=weights) + self.transforms = weights.transforms() + + def pre_process(self, first_frame: Tensor, last_frame: Tensor) -> tuple[Tensor, Tensor]: + """Resize inputs to dimensions required by backbone. + + Args: + first_frame (Tensor): Starting frame of optical flow computation. + last_frame (Tensor): Last frame of optical flow computation. + Returns: + tuple[Tensor, Tensor]: Preprocessed first and last frame. + """ + first_frame = F.resize(first_frame, size=[520, 960], antialias=False) + last_frame = F.resize(last_frame, size=[520, 960], antialias=False) + return self.transforms(first_frame, last_frame) + + def forward(self, first_frame: Tensor, last_frame: Tensor) -> Tensor: + """Forward pass through the flow extractor. + + Args: + first_frame (Tensor): Batch of starting frames of shape (N, 3, H, W). + last_frame (Tensor): Batch of last frames of shape (N, 3, H, W). + Returns: + Tensor: Estimated optical flow map of shape (N, 2, H, W). + """ + height, width = first_frame.shape[-2:] + + # preprocess batch + first_frame, last_frame = self.pre_process(first_frame, last_frame) + + # get flow maps + with torch.no_grad(): + flows = self.model(first_frame, last_frame)[-1] + + # convert back to original size + flows = F.resize(flows, [height, width], antialias=False) + + return flows diff --git a/src/anomalib/models/ai_vad/lightning_model.py b/src/anomalib/models/ai_vad/lightning_model.py new file mode 100644 index 0000000000..c11d9ef9f6 --- /dev/null +++ b/src/anomalib/models/ai_vad/lightning_model.py @@ -0,0 +1,124 @@ +"""Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection. 
+
+Paper https://arxiv.org/pdf/2212.00789.pdf
+"""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+
+from omegaconf import DictConfig, ListConfig
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+from torch import Tensor
+
+from anomalib.models.ai_vad.torch_model import AiVadModel
+from anomalib.models.components import AnomalyModule
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["AiVad", "AiVadLightning"]
+
+
+class AiVad(AnomalyModule):
+    """AiVad: Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection.
+
+    Args:
+        box_score_thresh (float): Confidence threshold for region extraction stage.
+        n_velocity_bins (int): Number of discrete bins used for velocity histogram features.
+        use_velocity_features (bool): Flag indicating if velocity features should be used.
+        use_pose_features (bool): Flag indicating if pose features should be used.
+        use_deep_features (bool): Flag indicating if deep features should be used.
+        n_components_velocity (int): Number of components used by GMM density estimation for velocity features.
+        n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features.
+        n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features.
+    """
+
+    def __init__(
+        self,
+        box_score_thresh: float = 0.8,
+        n_velocity_bins: int = 8,
+        use_velocity_features: bool = True,
+        use_pose_features: bool = True,
+        use_deep_features: bool = True,
+        n_components_velocity: int = 5,
+        n_neighbors_pose: int = 1,
+        n_neighbors_deep: int = 1,
+    ) -> None:
+        super().__init__()
+
+        self.model = AiVadModel(
+            box_score_thresh=box_score_thresh,
+            n_velocity_bins=n_velocity_bins,
+            use_velocity_features=use_velocity_features,
+            use_pose_features=use_pose_features,
+            use_deep_features=use_deep_features,
+            n_components_velocity=n_components_velocity,
+            n_neighbors_pose=n_neighbors_pose,
+            n_neighbors_deep=n_neighbors_deep,
+        )
+
+    @staticmethod
+    def configure_optimizers() -> None:
+        """AI-VAD training does not involve fine-tuning of NN weights, no optimizers needed."""
+        return None
+
+    def training_step(self, batch: dict[str, str | Tensor]) -> None:
+        """Training Step of AI-VAD.
+
+        Extract features from the batch of clips and update the density estimators.
+
+        Args:
+            batch (dict[str, str | Tensor]): Batch containing image filename, image, label and mask
+        """
+        features_per_batch = self.model(batch["image"])
+
+        for features, video_path in zip(features_per_batch, batch["video_path"]):
+            self.model.density_estimator.update(features, video_path)
+
+    def on_validation_start(self) -> None:
+        """Fit the density estimators to the extracted features from the training set."""
+        # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
+        # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
+        # is run within train epoch.
+        self.model.density_estimator.fit()
+
+    def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT:
+        """Validation Step of AI-VAD.
+
+        Extract boxes and box scores.
+
+        Args:
+            batch (dict[str, str | Tensor]): Input batch
+
+        Returns:
+            Batch dictionary with added boxes and box scores.
+        """
+        boxes, anomaly_scores, image_scores = self.model(batch["image"])
+        batch["pred_boxes"] = [box.int() for box in boxes]
+        batch["box_scores"] = [score.to(self.device) for score in anomaly_scores]
+        batch["pred_scores"] = Tensor(image_scores).to(self.device)
+
+        return batch
+
+
+class AiVadLightning(AiVad):
+    """AiVad: Attribute-based Representations for Accurate and Interpretable Video Anomaly Detection.
+
+    Args:
+        hparams (DictConfig | ListConfig): Model params
+    """
+
+    def __init__(self, hparams: DictConfig | ListConfig) -> None:
+        super().__init__(
+            box_score_thresh=hparams.model.box_score_thresh,
+            n_velocity_bins=hparams.model.n_velocity_bins,
+            use_velocity_features=hparams.model.use_velocity_features,
+            use_pose_features=hparams.model.use_pose_features,
+            use_deep_features=hparams.model.use_deep_features,
+            n_components_velocity=hparams.model.n_components_velocity,
+            n_neighbors_pose=hparams.model.n_neighbors_pose,
+            n_neighbors_deep=hparams.model.n_neighbors_deep,
+        )
+        self.hparams: DictConfig | ListConfig  # type: ignore
+        self.save_hyperparameters(hparams)
diff --git a/src/anomalib/models/ai_vad/regions.py b/src/anomalib/models/ai_vad/regions.py
new file mode 100644
index 0000000000..e48d711aaa
--- /dev/null
+++ b/src/anomalib/models/ai_vad/regions.py
@@ -0,0 +1,37 @@
+"""Regions extraction module of AI-VAD model implementation."""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights, maskrcnn_resnet50_fpn_v2
+
+
+class RegionExtractor(nn.Module):
+    """Region extractor for AI-VAD.
+
+    Args:
+        box_score_thresh (float): Confidence threshold for bounding box predictions.
+    """
+
+    def __init__(self, box_score_thresh: float = 0.8) -> None:
+        super().__init__()
+
+        weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
+        self.backbone = maskrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=box_score_thresh, rpn_nms_thresh=0.3)
+
+    def forward(self, batch: Tensor) -> list[dict]:
+        """Forward pass through region extractor.
+
+        Args:
+            batch (Tensor): Batch of input images of shape (N, C, H, W)
+        Returns:
+            list[dict]: List of Mask RCNN predictions for each image in the batch.
+        """
+        with torch.no_grad():
+            regions = self.backbone(batch)
+
+        return regions
diff --git a/src/anomalib/models/ai_vad/torch_model.py b/src/anomalib/models/ai_vad/torch_model.py
new file mode 100644
index 0000000000..90cdb7f7a3
--- /dev/null
+++ b/src/anomalib/models/ai_vad/torch_model.py
@@ -0,0 +1,111 @@
+"""PyTorch model for AI-VAD model implementation.
+
+Paper https://arxiv.org/pdf/2212.00789.pdf
+"""
+
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import torch
+from torch import Tensor, nn
+
+from anomalib.models.ai_vad.density import CombinedDensityEstimator
+from anomalib.models.ai_vad.features import FeatureExtractor
+from anomalib.models.ai_vad.flow import FlowExtractor
+from anomalib.models.ai_vad.regions import RegionExtractor
+
+
+class AiVadModel(nn.Module):
+    """AI-VAD model.
+
+    Args:
+        box_score_thresh (float): Confidence threshold for region extraction stage.
+        n_velocity_bins (int): Number of discrete bins used for velocity histogram features.
+        use_velocity_features (bool): Flag indicating if velocity features should be used.
+        use_pose_features (bool): Flag indicating if pose features should be used.
+        use_deep_features (bool): Flag indicating if deep features should be used.
+        n_components_velocity (int): Number of components used by GMM density estimation for velocity features.
+        n_neighbors_pose (int): Number of neighbors used in KNN density estimation for pose features.
+        n_neighbors_deep (int): Number of neighbors used in KNN density estimation for deep features.
+ """ + + def __init__( + self, + # region-extraction params + box_score_thresh: float = 0.8, + # feature-extraction params + n_velocity_bins: int = 8, + use_velocity_features: bool = True, + use_pose_features: bool = True, + use_deep_features: bool = True, + # density-estimation params + n_components_velocity: int = 5, + n_neighbors_pose: int = 1, + n_neighbors_deep: int = 1, + ): + super().__init__() + if not any((use_velocity_features, use_pose_features, use_deep_features)): + raise ValueError("Select at least one feature type.") + + # initialize flow extractor + self.flow_extractor = FlowExtractor() + # initialize region extractor + self.region_extractor = RegionExtractor(box_score_thresh=box_score_thresh) + # initialize feature extractor + self.feature_extractor = FeatureExtractor( + n_velocity_bins=n_velocity_bins, + use_velocity_features=use_velocity_features, + use_pose_features=use_pose_features, + use_deep_features=use_deep_features, + ) + # initialize density estimator + self.density_estimator = CombinedDensityEstimator( + use_velocity_features=use_velocity_features, + use_pose_features=use_pose_features, + use_deep_features=use_deep_features, + n_components_velocity=n_components_velocity, + n_neighbors_pose=n_neighbors_pose, + n_neighbors_deep=n_neighbors_deep, + ) + + def forward(self, batch: Tensor) -> tuple[list[Tensor], list[Tensor], list[Tensor]]: + """Forward pass through AI-VAD model. + + Args: + batch (Tensor): Input image of shape (N, L, C, H, W) + Returns: + list[Tensor]: List of bbox locations for each image. + list[Tensor]: List of per-bbox anomaly scores for each image. + list[Tensor]: List of per-image anomaly scores. + """ + self.flow_extractor.eval() + self.region_extractor.eval() + self.feature_extractor.eval() + + # 1. get first and last frame from clip + first_frame = batch[:, 0, ...] + last_frame = batch[:, -1, ...] + + # 2. extract flows and regions + with torch.no_grad(): + flows = self.flow_extractor(first_frame, last_frame) + regions = self.region_extractor(last_frame) + + # 3. extract pose, appearance and velocity features + features_per_batch = self.feature_extractor(first_frame, flows, regions) + + if self.training: + return features_per_batch + + # 4. estimate density + box_scores = [] + image_scores = [] + for features in features_per_batch: + box, image = self.density_estimator(features) + box_scores.append(box) + image_scores.append(image) + + box_locations = [batch_item["boxes"] for batch_item in regions] + return box_locations, box_scores, image_scores diff --git a/src/anomalib/post_processing/post_process.py b/src/anomalib/post_processing/post_process.py index ba24f76de5..bbd96f23ff 100644 --- a/src/anomalib/post_processing/post_process.py +++ b/src/anomalib/post_processing/post_process.py @@ -166,6 +166,6 @@ def draw_boxes(image: np.ndarray, boxes: np.ndarray, color: tuple[int, int, int] np.ndarray: Image showing the bounding boxes drawn on top of the source image. """ for box in boxes: - x_1, y_1, x_2, y_2 = box.astype(np.int) + x_1, y_1, x_2, y_2 = box.astype(int) image = cv2.rectangle(image, (x_1, y_1), (x_2, y_2), color=color, thickness=2) return image diff --git a/third-party-programs.txt b/third-party-programs.txt index bdfab12be5..3155b2a930 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -38,3 +38,7 @@ terms are listed below. 6. Reference aupro-implementation contained in ./tests/helpers/aupro_reference.py Copyright (c) 2021 @eliahuhorwitz, https://github.com/eliahuhorwitz/3D-ADS. 
    SPDX-License-Identifier: MIT
+
+7. CLIP neural network used for deep feature extraction in AI-VAD model
+   Copyright (c) 2022 @openai, https://github.com/openai/CLIP.
+   SPDX-License-Identifier: MIT
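
A note on the box bookkeeping in FeatureExtractor.forward above: torchvision.ops.roi_align accepts boxes either as a list of per-image tensors or as a single (M, 5) tensor whose first column is the batch index of each box; the module builds the latter so that extracted region features can later be regrouped per image via that index column. A minimal standalone sketch of the conversion (the tensor values and the rois name are illustrative, not part of the module):

import torch

# two images: the first has 2 detected boxes, the second has 1
boxes_list = [
    torch.tensor([[10.0, 10.0, 50.0, 80.0], [60.0, 20.0, 90.0, 70.0]]),
    torch.tensor([[5.0, 5.0, 40.0, 40.0]]),
]

# repeat each image index once per box in that image -> tensor([0, 0, 1])
indices = torch.repeat_interleave(
    torch.arange(len(boxes_list)), torch.tensor([boxes.shape[0] for boxes in boxes_list])
)

# stack into the (M, 5) layout where each row is [batch_index, x1, y1, x2, y2]
rois = torch.cat([indices.unsqueeze(1).float(), torch.cat(boxes_list)], dim=1)
print(rois.shape)  # torch.Size([3, 5])
# per-region outputs can afterwards be regrouped per image with e.g. results[indices == 0]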
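
In the same spirit, the core of VelocityExtractor is an orientation-binned mean of the flow magnitude: the flow vectors of each region are converted to polar coordinates, the magnitudes are accumulated per direction bin, and every non-empty bin is divided by its pixel count. A self-contained sketch of that computation for a single region (the velocity_histogram helper and its shapes are illustrative assumptions, not an anomalib API):

import torch

def velocity_histogram(flow_region: torch.Tensor, n_bins: int = 8) -> torch.Tensor:
    # flow_region: optical flow crop of shape (2, H, W); returns an (n_bins,) feature
    magnitude = torch.linalg.norm(flow_region, ord=2, dim=0)  # (H, W)
    angle = torch.atan2(flow_region[0], flow_region[1])       # (H, W), in [-pi, pi]

    # sum of flow magnitudes per direction bin
    weighted = torch.histogram(
        angle.flatten(), bins=n_bins, range=(-torch.pi, torch.pi), weight=magnitude.flatten()
    ).hist
    # number of pixels per direction bin
    counts = torch.histogram(angle.flatten(), bins=n_bins, range=(-torch.pi, torch.pi)).hist

    histogram = torch.zeros_like(weighted)
    mask = counts != 0
    histogram[mask] = weighted[mask] / counts[mask]  # mean magnitude per non-empty bin
    return histogram

features = velocity_histogram(torch.randn(2, 224, 224))
print(features.shape)  # torch.Size([8])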
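
Finally, PoseExtractor._post_process makes the pose features location- and scale-invariant by expressing every keypoint relative to its bounding box before flattening. A small illustrative sketch with made-up values (17 keypoints per detection is an assumption inherited from the Keypoint R-CNN weights):

import torch

# one detection: box in [x1, y1, x2, y2] format and 17 keypoints in [x, y, visibility]
boxes = torch.tensor([[100.0, 50.0, 200.0, 250.0]])  # (1, 4)
keypoints = torch.rand(1, 17, 3) * torch.tensor([100.0, 200.0, 1.0]) + torch.tensor([100.0, 50.0, 0.0])

boxes = boxes.unsqueeze(1)  # (1, 1, 4) for broadcasting against the keypoints
normalized = (keypoints[..., :2] - boxes[..., :2]) / (boxes[..., 2:] - boxes[..., :2])

# flatten to one 34-dimensional pose vector per detection
pose_feature = normalized.reshape(normalized.shape[0], -1)
print(pose_feature.shape)  # torch.Size([1, 34])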