Skip to content

Commit

Permalink
updating optimzier_step
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 1, 2022
1 parent 7aa0a49 commit ea79146
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions frame_semantic_transformer/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Optional
from typing import Any, Callable, Optional
import pytorch_lightning as pl
from transformers import (
T5ForConditionalGeneration,
Expand All @@ -9,8 +9,6 @@
import torch
from torch.utils.data import DataLoader, Dataset

from frame_semantic_transformer.data.load_framenet_samples import load_framenet_samples


class T5FineTuner(pl.LightningModule):
"""
Expand Down Expand Up @@ -54,7 +52,6 @@ def __init__(
self.gradient_accumulation_steps = gradient_accumulation_steps
self.warmup_steps = warmup_steps

samples = load_framenet_samples()
self.train_dataset = train_dataset
self.val_dataset = val_dataset

Expand Down Expand Up @@ -145,9 +142,12 @@ def optimizer_step( # type: ignore
self,
epoch: int,
batch_idx: int,
optimizer: Any,
optimizer_idx: int,
second_order_closure: Optional[Any] = None,
optimizer: torch.optim.Optimizer,
optimizer_idx: int = 0,
optimizer_closure: Optional[Callable[[], Any]] = None,
on_tpu: bool = False,
using_native_amp: bool = False,
using_lbfgs: bool = False,
) -> None:
optimizer.step()
optimizer.zero_grad()
Expand Down

0 comments on commit ea79146

Please sign in to comment.