
Commit

lora arg
wonjuleee committed Jul 15, 2024
1 parent 206f3d3 commit 92ef6d7
Showing 3 changed files with 111 additions and 21 deletions.
38 changes: 22 additions & 16 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -13,16 +13,16 @@
LayerType,
Mlp,
PatchDropout,
PatchEmbed,
SwiGLUPacked,
get_act_layer,
get_norm_layer,
resample_abs_pos_embed,
resample_patch_embed,
trunc_normal_,
)
from timm.layers import PatchEmbed as TimmPatchEmbed
from timm.models._manipulate import adapt_input_conv
from timm.models.vision_transformer import Block
from timm.models.vision_transformer import Attention, Block
from torch import nn

from otx.algo.modules.base_module import BaseModule
@@ -203,12 +203,12 @@ def __init__( # noqa: PLR0913
proj_drop_rate: float = 0.0,
attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0,
embed_layer: Callable = TimmPatchEmbed,
embed_layer: Callable = PatchEmbed,
block_fn: nn.Module = Block,
mlp_layer: nn.Module | None = None,
act_layer: LayerType | None = None,
norm_layer: LayerType | None = None,
use_lora: bool = False,
lora: bool = False,
) -> None:
super().__init__()
if isinstance(arch, str):
@@ -292,11 +292,11 @@ def __init__( # noqa: PLR0913

self.norm = norm_layer(embed_dim)

self.use_lora = use_lora
if self.use_lora:
self.lora = lora
if self.lora:
lora_rank = 8
lora_alpha = 1.0
assign_lora = partial(QkvWithLoRA, rank=lora_rank, alpha=lora_alpha)
assign_lora = partial(AttentionWithLoRA, rank=lora_rank, alpha=lora_alpha)
for block in self.blocks:
block.attn.qkv = assign_lora(block.attn.qkv)

@@ -481,6 +481,7 @@ def _n2p(w: np.ndarray, t: bool = True, idx: int | None = None) -> torch.Tensor:
model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))

mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
idx: int | None = None
for i, block in enumerate(model.blocks.children()):
if f"{prefix}Transformer/encoderblock/LayerNorm_0/scale" in w:
block_prefix = f"{prefix}Transformer/encoderblock/"
@@ -526,7 +527,7 @@ def _n2p(w: np.ndarray, t: bool = True, idx: int | None = None) -> torch.Tensor:
model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))

mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
for i, block in enumerate(model.blocks.children()):
for i, block in enumerate(model.blocks.children()): # noqa: PLW2901
if f"{prefix}Transformer/encoderblock/LayerNorm_0/scale" in w:
block_prefix = f"{prefix}Transformer/encoderblock/"
idx = i
@@ -536,7 +537,7 @@
mha_prefix = block_prefix + f"MultiHeadDotProductAttention_{mha_sub}/"
block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"], idx=idx))
block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"], idx=idx))
if not self.use_lora:
if not self.lora:
block.attn.qkv.weight.copy_(
torch.cat(
[
@@ -584,27 +585,32 @@ def _n2p(w: np.ndarray, t: bool = True, idx: int | None = None) -> torch.Tensor:


class LoRALayer(torch.nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha):
"""LoRA layer implementation for computing A, B composition."""

def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
super().__init__()
std = torch.sqrt(torch.tensor(rank).float())
self.A = torch.nn.Parameter(torch.randn(in_dim, rank) / std)
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
self.alpha = alpha

def forward(self, x):
x = self.alpha * (x @ self.A @ self.B)
return x
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the LoRA layer."""
return self.alpha * (x @ self.A @ self.B)


class AttentionWithLoRA(torch.nn.Module):
"""Add LoRA layer into QKV attention layer in VisionTransformer."""

class QkvWithLoRA(torch.nn.Module):
def __init__(self, qkv, rank, alpha):
def __init__(self, qkv: Attention, rank: int, alpha: float):
super().__init__()
self.qkv = qkv
self.dim = qkv.in_features
self.lora_q = LoRALayer(self.dim, self.dim, rank, alpha)
self.lora_v = LoRALayer(self.dim, self.dim, rank, alpha)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the AttentionWithLoRA."""
qkv = self.qkv(x)
qkv[:, :, : self.dim] += self.lora_q(x)
qkv[:, :, -self.dim :] += self.lora_v(x)
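The substance of the backbone change above is a standard LoRA adapter applied to the fused qkv projection: when the backbone is built with `lora=True`, each block's `attn.qkv` is wrapped so that a rank-8, alpha-1.0 low-rank update (`alpha * x @ A @ B`) is added to the query and value slices of its output. The sketch below is a minimal, standalone illustration of that mechanism; it mirrors the `LoRALayer` and `AttentionWithLoRA` classes in the diff but substitutes a plain `nn.Linear` for timm's `Attention.qkv`, and the toy dimensions are only for the shape check, so read it as a sketch rather than the OTX implementation itself.

```python
import torch
from torch import nn


class LoRALayer(nn.Module):
    """Rank-r update: x -> alpha * (x @ A @ B). B starts at zero, so the wrapped layer is unchanged at init."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float) -> None:
        super().__init__()
        std = torch.sqrt(torch.tensor(rank).float())
        self.A = nn.Parameter(torch.randn(in_dim, rank) / std)  # random down-projection
        self.B = nn.Parameter(torch.zeros(rank, out_dim))        # zero up-projection
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.alpha * (x @ self.A @ self.B)


class AttentionWithLoRA(nn.Module):
    """Wrap a fused qkv projection and add LoRA updates to the query and value slices only."""

    def __init__(self, qkv: nn.Linear, rank: int, alpha: float) -> None:
        super().__init__()
        self.qkv = qkv
        self.dim = qkv.in_features
        self.lora_q = LoRALayer(self.dim, self.dim, rank, alpha)
        self.lora_v = LoRALayer(self.dim, self.dim, rank, alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv(x)                         # (batch, tokens, 3 * dim)
        qkv[:, :, : self.dim] += self.lora_q(x)   # query slice gets the low-rank update
        qkv[:, :, -self.dim :] += self.lora_v(x)  # value slice too; the key slice is untouched
        return qkv


# Toy shape check with a stand-in qkv projection (dim=192 matches a "tiny" ViT width).
dim = 192
wrapped = AttentionWithLoRA(nn.Linear(dim, dim * 3), rank=8, alpha=1.0)
out = wrapped(torch.randn(2, 197, dim))
print(out.shape)  # torch.Size([2, 197, 576])
```

Because `B` is zero-initialized, wrapping the projection leaves the network's outputs unchanged until training updates `A` and `B`; each adapted slice only adds `2 * dim * rank` extra parameters.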
93 changes: 88 additions & 5 deletions src/otx/algo/classification/vit.py
@@ -220,13 +220,15 @@ def __init__(
self,
label_info: LabelInfoTypes,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiClassClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained
super().__init__(
label_info=label_info,
@@ -279,7 +281,7 @@ def _build_model(self, num_classes: int) -> nn.Module:
{"std": 0.2, "layer": "Linear", "type": "TruncNormal"},
{"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"},
]
vit_backbone = VisionTransformer(arch=self.arch, img_size=224)
vit_backbone = VisionTransformer(arch=self.arch, img_size=224, lora=self.lora)
return ImageClassifier(
backbone=vit_backbone,
neck=None,
@@ -446,7 +448,7 @@ def training_step(self, batch: MulticlassClsBatchDataEntity, batch_idx: int) ->
return loss


class VisionTransformerForMultilabelCls(VisionTransformerForMulticlassCls, OTXMultilabelClsModel):
class VisionTransformerForMultilabelCls(ForwardExplainMixInForViT, OTXMultilabelClsModel):
"""DeitTiny Model for multi-class classification task."""

model: ImageClassifier
@@ -455,13 +457,15 @@ def __init__(
self,
label_info: LabelInfoTypes,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = MultiLabelClsMetricCallable,
torch_compile: bool = False,
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained

super().__init__(
@@ -472,12 +476,49 @@
torch_compile=torch_compile,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
for key in list(state_dict.keys()):
new_key = key.replace("patch_embed.projection", "patch_embed.proj")
new_key = new_key.replace("backbone.ln1", "backbone.norm")
new_key = new_key.replace("ffn.layers.0.0", "mlp.fc1")
new_key = new_key.replace("ffn.layers.1", "mlp.fc2")
new_key = new_key.replace("layers", "blocks")
new_key = new_key.replace("ln", "norm")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)

def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_model_dict = self._build_model(num_classes=5).state_dict()
incremental_model_dict = self._build_model(num_classes=6).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)

model = self._build_model(num_classes=self.num_classes)
model.init_weights()
if self.pretrained and self.arch in pretrained_urls:
print(f"init weight - {pretrained_urls[self.arch]}")
parts = urlparse(pretrained_urls[self.arch])
filename = Path(parts.path).name

cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
cache_file = cache_dir / filename
if not Path.exists(cache_file):
download_url_to_file(pretrained_urls[self.arch], cache_file, "", progress=True)
model.backbone.load_pretrained(checkpoint_path=cache_file)
return model

def _build_model(self, num_classes: int) -> nn.Module:
init_cfg = [
{"std": 0.2, "layer": "Linear", "type": "TruncNormal"},
{"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"},
]
vit_backbone = VisionTransformer(arch=self.arch, img_size=224)
vit_backbone = VisionTransformer(arch=self.arch, img_size=224, lora=self.lora)
return ImageClassifier(
backbone=vit_backbone,
neck=None,
@@ -537,7 +578,7 @@ def _customize_outputs(
)


class VisionTransformerForHLabelCls(VisionTransformerForMulticlassCls, OTXHlabelClsModel):
class VisionTransformerForHLabelCls(ForwardExplainMixInForViT, OTXHlabelClsModel):
"""DeitTiny Model for hierarchical label classification task."""

model: ImageClassifier
@@ -547,13 +588,15 @@ def __init__(
self,
label_info: HLabelInfo,
arch: VIT_ARCH_TYPE = "vit-tiny",
lora: bool = False,
pretrained: bool = True,
optimizer: OptimizerCallable = DefaultOptimizerCallable,
scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
metric: MetricCallable = HLabelClsMetricCallble,
torch_compile: bool = False,
) -> None:
self.arch = arch
self.lora = lora
self.pretrained = pretrained

super().__init__(
@@ -564,14 +607,54 @@
torch_compile=torch_compile,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
for key in list(state_dict.keys()):
new_key = key.replace("patch_embed.projection", "patch_embed.proj")
new_key = new_key.replace("backbone.ln1", "backbone.norm")
new_key = new_key.replace("ffn.layers.0.0", "mlp.fc1")
new_key = new_key.replace("ffn.layers.1", "mlp.fc2")
new_key = new_key.replace("layers", "blocks")
new_key = new_key.replace("ln", "norm")
if new_key != key:
state_dict[new_key] = state_dict.pop(key)
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)

def _create_model(self) -> nn.Module:
# Get classification_layers for class-incr learning
sample_config = deepcopy(self.label_info.as_head_config_dict())
sample_config["num_classes"] = 5
sample_model_dict = self._build_model(head_config=sample_config).state_dict()
sample_config["num_classes"] = 6
incremental_model_dict = self._build_model(head_config=sample_config).state_dict()
self.classification_layers = get_classification_layers(
sample_model_dict,
incremental_model_dict,
prefix="model.",
)

model = self._build_model(head_config=self.label_info.as_head_config_dict())
model.init_weights()
if self.pretrained and self.arch in pretrained_urls:
print(f"init weight - {pretrained_urls[self.arch]}")
parts = urlparse(pretrained_urls[self.arch])
filename = Path(parts.path).name

cache_dir = Path.home() / ".cache" / "torch" / "hub" / "checkpoints"
cache_file = cache_dir / filename
if not Path.exists(cache_file):
download_url_to_file(pretrained_urls[self.arch], cache_file, "", progress=True)
model.backbone.load_pretrained(checkpoint_path=cache_file)
return model

def _build_model(self, head_config: dict) -> nn.Module:
if not isinstance(self.label_info, HLabelInfo):
raise TypeError(self.label_info)
init_cfg = [
{"std": 0.2, "layer": "Linear", "type": "TruncNormal"},
{"bias": 0.0, "val": 1.0, "layer": "LayerNorm", "type": "Constant"},
]
vit_backbone = VisionTransformer(arch=self.arch, img_size=224)
vit_backbone = VisionTransformer(arch=self.arch, img_size=224, lora=self.lora)
return ImageClassifier(
backbone=vit_backbone,
neck=None,
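In `vit.py` the same flag is threaded through each ViT classifier constructor and lands in `VisionTransformer(arch=self.arch, img_size=224, lora=self.lora)`. The snippet below is a rough smoke test of that path at the backbone level; the import path is inferred from the changed file's location, and an installed OTX/timm environment is assumed, so take it as an illustration rather than documented API.

```python
# Import path inferred from the changed file
# src/otx/algo/classification/backbones/vision_transformer.py (assumption).
from otx.algo.classification.backbones.vision_transformer import VisionTransformer

# lora=True wraps every block's fused qkv projection with AttentionWithLoRA;
# rank and alpha are fixed inside the backbone (8 and 1.0 in this commit).
backbone = VisionTransformer(arch="vit-tiny", img_size=224, lora=True)

# The new trainable tensors are the LoRA A/B matrices attached to the q and v slices.
lora_params = [name for name, _ in backbone.named_parameters() if ".lora_" in name]
print(len(lora_params), "LoRA parameter tensors")  # expect 4 per transformer block
```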
(Third changed file, 1 addition and 0 deletions: a YAML model config for the ViT classifier; the file path is not preserved in this view.)
@@ -3,6 +3,7 @@ model:
init_args:
label_info: 1000
arch: "vit-tiny"
lora: False

optimizer:
class_path: torch.optim.AdamW
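The config change is the user-facing side of the same switch: `lora` now sits under `init_args` next to `label_info` and `arch` and defaults to `False`, so existing configs behave as before. Flipping it to `True` in YAML corresponds, roughly, to the constructor call below; the class path and the assumption that this config targets the multi-class model (suggested by `label_info: 1000` and `arch: "vit-tiny"`) are inferences, not something the page states.

```python
# Rough Python equivalent of changing `lora: False` to `lora: True` in the config
# (class location assumed from the changed file src/otx/algo/classification/vit.py).
from otx.algo.classification.vit import VisionTransformerForMulticlassCls

model = VisionTransformerForMulticlassCls(
    label_info=1000,   # mirrors the config's label_info
    arch="vit-tiny",   # mirrors the config's arch
    lora=True,         # the argument added by this commit
    pretrained=False,  # skip the checkpoint download for a quick local check
)
```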
