model.py
from typing import Dict

import torch
from torch.nn import Embedding


class SeqClassifier(torch.nn.Module):
    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super(SeqClassifier, self).__init__()
        # Pre-trained word embeddings; freeze=False allows them to be fine-tuned.
        self.embed = Embedding.from_pretrained(embeddings, freeze=False)
        # TODO: model architecture

    @property
    def encoder_output_size(self) -> int:
        # TODO: calculate the output dimension of rnn
        raise NotImplementedError

    def forward(self, batch) -> Dict[str, torch.Tensor]:
        # TODO: implement model forward
        raise NotImplementedError


class SeqTagger(SeqClassifier):
    def forward(self, batch) -> Dict[str, torch.Tensor]:
        # TODO: implement model forward
        raise NotImplementedError
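
One possible way to complete the TODOs is sketched below, assuming a (bidirectional) GRU encoder over the embedded tokens and a linear classification head applied to the final time step. The class name ExampleSeqClassifier and the "logits" output key are illustrative choices, not part of the original skeleton, and the actual architecture is left open by the assignment.

from typing import Dict

import torch
from torch.nn import Embedding, GRU, Linear


class ExampleSeqClassifier(torch.nn.Module):
    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super().__init__()
        self.embed = Embedding.from_pretrained(embeddings, freeze=False)
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        # GRU encoder over the embedded token sequence.
        self.rnn = GRU(
            input_size=embeddings.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        # Classification head mapping the encoder output to class logits.
        self.classifier = Linear(self.encoder_output_size, num_class)

    @property
    def encoder_output_size(self) -> int:
        # A bidirectional RNN concatenates forward and backward hidden states.
        return self.hidden_size * 2 if self.bidirectional else self.hidden_size

    def forward(self, batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        # batch: (batch_size, seq_len) token ids.
        embedded = self.embed(batch)                 # (B, L, emb_dim)
        outputs, _ = self.rnn(embedded)              # (B, L, encoder_output_size)
        # Use the last time step as a sentence representation for classification.
        logits = self.classifier(outputs[:, -1, :])  # (B, num_class)
        return {"logits": logits}

For the tagging variant (SeqTagger), the same encoder could be reused, but the classifier would be applied to every time step so that the forward pass returns per-token logits of shape (B, L, num_class) instead of a single sentence-level prediction.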