forked from ArneBinder/pytorch-ie
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathannotations.py
125 lines (87 loc) · 3.02 KB
/
annotations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from dataclasses import dataclass
from typing import Optional, Tuple
from pytorch_ie.core.document import Annotation
def _post_init_single_label(self):
if not isinstance(self.label, str):
raise ValueError("label must be a single string.")
if not isinstance(self.score, float):
raise ValueError("score must be a single float.")
def _post_init_multi_label(self):
if self.score is None:
score = tuple([1.0] * len(self.label))
object.__setattr__(self, "score", score)
if not isinstance(self.label, tuple):
object.__setattr__(self, "label", tuple(self.label))
if not isinstance(self.score, tuple):
object.__setattr__(self, "score", tuple(self.score))
if len(self.label) != len(self.score):
raise ValueError(
f"Number of labels ({len(self.label)}) and scores ({len(self.score)}) must be equal."
)
def _post_init_multi_span(self):
if isinstance(self.slices, list):
object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices))
@dataclass(frozen=True)
class Label(Annotation):
    """An annotation carrying exactly one string label and a float score.

    Instances are immutable. ``eq=True`` is the dataclass default, so value
    equality and hashing are generated.
    """

    label: str  # must be a str; checked in __post_init__
    score: float = 1.0  # must be a float (an int such as 1 is rejected)

    def __post_init__(self) -> None:
        # Delegate type validation to the shared single-label checker.
        _post_init_single_label(self)
@dataclass(frozen=True)
class MultiLabel(Annotation):
    """An annotation carrying several string labels, each with its own score.

    ``label`` may be passed as e.g. a list and is normalized to a tuple. If
    ``score`` is omitted, every label receives a score of 1.0. The label and
    score counts must match.
    """

    label: Tuple[str, ...]  # normalized to a tuple in __post_init__
    score: Optional[Tuple[float, ...]] = None  # None -> all 1.0

    def __post_init__(self) -> None:
        # Shared normalization/validation for label/score tuples.
        _post_init_multi_label(self)
@dataclass(frozen=True)
class Span(Annotation):
    """A contiguous region ``[start, end)`` of the annotation's target.

    NOTE(review): the target is presumably the document text; confirm against
    ``Annotation.target``.
    """

    start: int  # inclusive start offset
    end: int  # exclusive end offset (used as a slice stop)

    def __str__(self) -> str:
        """Return ``str(target[start:end])``, or an empty string when target is None."""
        return "" if self.target is None else str(self.target[self.start : self.end])
@dataclass(frozen=True)
class LabeledSpan(Span):
    """A ``Span`` with one string label and a float score."""

    label: str  # must be a str; checked in __post_init__
    score: float = 1.0  # must be a float (an int such as 1 is rejected)

    def __post_init__(self) -> None:
        # Reuse the shared single-label type checks.
        _post_init_single_label(self)
@dataclass(frozen=True)
class MultiLabeledSpan(Span):
    """A ``Span`` with several string labels, each paired with a score.

    ``label`` is normalized to a tuple. A missing ``score`` becomes 1.0 per
    label, and the label and score counts must match.
    """

    label: Tuple[str, ...]  # normalized to a tuple in __post_init__
    score: Optional[Tuple[float, ...]] = None  # None -> all 1.0

    def __post_init__(self) -> None:
        # Shared normalization/validation for label/score tuples.
        _post_init_multi_label(self)
@dataclass(frozen=True)
class LabeledMultiSpan(Annotation):
    """An annotation covering several ``(start, end)`` slices, with one label and score."""

    slices: Tuple[Tuple[int, int], ...]  # a list of pairs is normalized to tuples
    label: str  # must be a str
    score: float = 1.0  # must be a float (an int such as 1 is rejected)

    def __post_init__(self) -> None:
        # Normalize slices first, then validate label/score types.
        _post_init_multi_span(self)
        _post_init_single_label(self)
@dataclass(frozen=True)
class MultiLabeledMultiSpan(Annotation):
    """An annotation covering several ``(start, end)`` slices, with several labels and scores."""

    slices: Tuple[Tuple[int, int], ...]  # a list of pairs is normalized to tuples
    label: Tuple[str, ...]  # normalized to a tuple
    score: Optional[Tuple[float, ...]] = None  # None -> all 1.0

    def __post_init__(self) -> None:
        # Normalize slices first, then the label/score tuples.
        _post_init_multi_span(self)
        _post_init_multi_label(self)
@dataclass(frozen=True)
class BinaryRelation(Annotation):
    """A relation between two spans, ``head`` and ``tail``, with one label and score."""

    head: Span  # first argument of the relation
    tail: Span  # second argument of the relation
    label: str  # must be a str
    score: float = 1.0  # must be a float (an int such as 1 is rejected)

    def __post_init__(self) -> None:
        # Reuse the shared single-label type checks.
        _post_init_single_label(self)
@dataclass(frozen=True)
class MultiLabeledBinaryRelation(Annotation):
    """A relation between two spans, ``head`` and ``tail``, with several labels and scores."""

    head: Span  # first argument of the relation
    tail: Span  # second argument of the relation
    label: Tuple[str, ...]  # normalized to a tuple
    score: Optional[Tuple[float, ...]] = None  # None -> all 1.0

    def __post_init__(self) -> None:
        # Shared normalization/validation for label/score tuples.
        _post_init_multi_label(self)