Commit 3da9451: fixes issue #346

oguiza committed Dec 28, 2021 (1 parent: 1e54768)
Showing 10 changed files with 410 additions and 537 deletions.
219 changes: 201 additions & 18 deletions nbs/050_losses.ipynb
@@ -98,7 +98,7 @@
{
"data": {
"text/plain": [
"tensor(0.5384)"
"tensor(0.4309)"
]
},
"execution_count": null,
@@ -137,17 +137,10 @@
"execution_count": null,
"metadata": {},
"outputs": [
- {
-  "name": "stdout",
-  "output_type": "stream",
-  "text": [
-   "torch.Size([8, 30])\n"
-  ]
- },
{
"data": {
"text/plain": [
"(tensor(nan), tensor(0.8645))"
"(tensor(nan), tensor(0.8740))"
]
},
"execution_count": null,
@@ -162,27 +155,217 @@
"nn.L1Loss()(inp, targ), MaskedLossWrapper(nn.L1Loss())(inp, targ)"
]
},
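
The cell above contrasts a plain nn.L1Loss, which returns nan as soon as the target contains nans, with MaskedLossWrapper, whose definition lies outside this hunk. A minimal sketch of the idea, assuming the wrapper simply drops nan targets before delegating to the wrapped loss (the class name below is hypothetical, not the library's):

    import torch
    import torch.nn as nn

    class MaskedLossWrapperSketch(nn.Module):
        "Hypothetical wrapper: applies `loss` only where the target is not nan."
        def __init__(self, loss):
            super().__init__()
            self.loss = loss

        def forward(self, inp, targ):
            inp, targ = inp.flatten(), targ.flatten()
            mask = ~torch.isnan(targ)  # keep only labeled positions
            return self.loss(inp[mask], targ[mask])
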
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class CenterLoss(Module):\n",
" r\"\"\"\n",
" Code in Pytorch has been slightly modified from: https://github.com/KaiyangZhou/pytorch-center-loss/blob/master/center_loss.py\n",
" Based on paper: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.\n",
"\n",
" Args:\n",
" c_out (int): number of classes.\n",
" logits_dim (int): dim 1 of the logits. By default same as c_out (for one hot encoded logits)\n",
" \n",
" \"\"\"\n",
" def __init__(self, c_out, logits_dim=None):\n",
" logits_dim = ifnone(logits_dim, c_out)\n",
" self.c_out, self.logits_dim = c_out, logits_dim\n",
" self.centers = nn.Parameter(torch.randn(c_out, logits_dim))\n",
" self.classes = nn.Parameter(torch.arange(c_out).long(), requires_grad=False)\n",
"\n",
" def forward(self, x, labels):\n",
" \"\"\"\n",
" Args:\n",
" x: feature matrix with shape (batch_size, logits_dim).\n",
" labels: ground truth labels with shape (batch_size).\n",
" \"\"\"\n",
" bs = x.shape[0]\n",
" distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(bs, self.c_out) + \\\n",
" torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.c_out, bs).T\n",
" distmat = torch.addmm(distmat, x, self.centers.T, beta=1, alpha=-2)\n",
"\n",
" labels = labels.unsqueeze(1).expand(bs, self.c_out)\n",
" mask = labels.eq(self.classes.expand(bs, self.c_out))\n",
"\n",
" dist = distmat * mask.float()\n",
" loss = dist.clamp(min=1e-12, max=1e+12).sum() / bs\n",
"\n",
" return loss\n",
"\n",
"\n",
"class CenterPlusLoss(Module):\n",
" \n",
" def __init__(self, loss, c_out, λ=1e-2, logits_dim=None):\n",
" self.loss, self.c_out, self.λ = loss, c_out, λ\n",
" self.centerloss = CenterLoss(c_out, logits_dim)\n",
" \n",
" def forward(self, x, labels):\n",
" return self.loss(x, labels) + self.λ * self.centerloss(x, labels)\n",
" def __repr__(self): return f\"CenterPlusLoss(loss={self.loss}, c_out={self.c_out}, λ={self.λ})\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(10.6629, grad_fn=<DivBackward0>),\n",
" TensorBase(2.3617, grad_fn=<AliasBackward0>))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c_in = 10\n",
"x = torch.rand(64, c_in).to(device=default_device())\n",
"x = F.softmax(x, dim=1)\n",
"label = x.max(dim=1).indices\n",
"CenterLoss(c_in).to(x.device)(x, label), CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in).to(x.device)(x, label)"
]
},
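
In the CenterLoss cell above, `distmat` expands ||x_i||^2 + ||c_j||^2 row- and column-wise and then subtracts 2 * x @ centers.T via torch.addmm, yielding the squared Euclidean distance between every sample and every class center; the mask then keeps only each sample's own center, and CenterPlusLoss adds λ (default 1e-2) times this term to a base loss. A quick cross-check of that vectorized expansion against a naive loop (a sketch, not part of the library):

    import torch
    import torch.nn.functional as F

    def naive_center_loss(x, labels, centers):
        # reference: mean squared distance of each sample to its class center
        return torch.stack([((x[i] - centers[labels[i]]) ** 2).sum()
                            for i in range(len(x))]).sum() / len(x)

    x = torch.rand(64, 10)
    labels = x.max(dim=1).indices
    centers = torch.randn(10, 10)
    distmat = x.pow(2).sum(1, keepdim=True) + centers.pow(2).sum(1) - 2 * x @ centers.T
    vectorized = (distmat * F.one_hot(labels, 10).float()).sum() / len(x)
    assert torch.allclose(vectorized, naive_center_loss(x, labels, centers), atol=1e-5)
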
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<img src onerror=\"\n",
" this.nextElementSibling.focus();\n",
" this.dispatchEvent(new KeyboardEvent('keydown', {key:'s', keyCode: 83, metaKey: true}));\n",
" \" style=\"display:none\"><input style=\"width:0;height:0;border:0\">"
],
"text/plain": [
"<IPython.core.display.HTML object>"
"CenterPlusLoss(loss=FlattenedLoss of LabelSmoothingCrossEntropy(), c_out=10, λ=0.01)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "display_data"
"output_type": "execute_result"
}
],
"source": [
"CenterPlusLoss(LabelSmoothingCrossEntropyFlat(), c_in)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class FocalLoss(Module):\n",
" \"\"\" Weighted, multiclass focal loss\"\"\"\n",
"\n",
" def __init__(self, alpha:Optional[Tensor]=None, gamma:float=2., reduction:str='mean'):\n",
" \"\"\"\n",
" Args:\n",
" alpha (Tensor, optional): Weights for each class. Defaults to None.\n",
" gamma (float, optional): A constant, as described in the paper. Defaults to 2.\n",
" reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.\n",
" \"\"\"\n",
" self.alpha, self.gamma, self.reduction = alpha, gamma, reduction\n",
" self.nll_loss = nn.NLLLoss(weight=alpha, reduction='none')\n",
"\n",
" def forward(self, x: Tensor, y: Tensor) -> Tensor:\n",
"\n",
" log_p = F.log_softmax(x, dim=-1)\n",
" pt = log_p[torch.arange(len(x)), y].exp()\n",
" ce = self.nll_loss(log_p, y)\n",
" loss = (1 - pt) ** self.gamma * ce\n",
"\n",
" if self.reduction == 'mean':\n",
" loss = loss.mean()\n",
" elif self.reduction == 'sum':\n",
" loss = loss.sum()\n",
" return loss\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1.5199)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = torch.normal(0, 2, (16, 2)).to(device=default_device())\n",
"targets = torch.randint(0, 2, (16,)).to(device=default_device())\n",
"FocalLoss()(inputs, targets)"
]
},
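
With gamma=0 and alpha=None, the focusing factor (1 - pt) ** gamma in FocalLoss collapses to 1, so the loss should reduce to plain cross-entropy; that makes a useful sanity check. A sketch of that check (variable names are illustrative):

    import torch
    import torch.nn.functional as F

    x = torch.normal(0, 2, (16, 4))
    y = torch.randint(0, 4, (16,))
    log_p = F.log_softmax(x, dim=-1)
    pt = log_p[torch.arange(len(x)), y].exp()
    focal0 = ((1 - pt) ** 0 * F.nll_loss(log_p, y, reduction='none')).mean()
    assert torch.allclose(focal0, F.cross_entropy(x, y))
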
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class TweedieLoss(Module):\n",
" def __init__(self, p=1.5, eps=1e-10):\n",
" \"\"\"\n",
" Tweedie loss as calculated in LightGBM\n",
" Args:\n",
" p: tweedie variance power (1 < p < 2)\n",
" eps: small number to avoid log(zero).\n",
" \"\"\"\n",
" assert p > 1 and p < 2, \"make sure 1 < p < 2\"\n",
" self.p, self.eps = p, eps\n",
"\n",
" def forward(self, inp, targ):\n",
" inp = inp.flatten()\n",
" targ = targ.flatten()\n",
" torch.clamp_min_(inp, self.eps)\n",
" a = targ * torch.exp((1 - self.p) * torch.log(inp)) / (1 - self.p)\n",
" b = torch.exp((2 - self.p) * torch.log(inp)) / (2 - self.p)\n",
" loss = -a + b\n",
" return loss.mean()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3.2877)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c_in = 10\n",
"output = torch.rand(64).to(device=default_device())\n",
"target = torch.rand(64).to(device=default_device())\n",
"TweedieLoss().to(output.device)(output, target)"
]
},
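
Since exp(k * log(m)) equals m ** k for m > 0, the forward above computes the Tweedie deviance up to terms constant in the prediction: loss = -targ * inp**(1 - p) / (1 - p) + inp**(2 - p) / (2 - p), averaged over the batch. A small check that the log/exp form used in the class matches the direct power form (a sketch under the same 1 < p < 2 assumption):

    import torch

    p, eps = 1.5, 1e-10
    inp = torch.rand(64).clamp_min(eps)
    targ = torch.rand(64)
    via_log = (-targ * torch.exp((1 - p) * torch.log(inp)) / (1 - p)
               + torch.exp((2 - p) * torch.log(inp)) / (2 - p)).mean()
    direct = (-targ * inp ** (1 - p) / (1 - p) + inp ** (2 - p) / (2 - p)).mean()
    assert torch.allclose(via_log, direct)
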
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from tsai.imports import create_scripts\n",
@@ -201,7 +384,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
52 changes: 35 additions & 17 deletions nbs/060_callback.core.ipynb

Large diffs are not rendered by default.
