Skip to content

Commit

Permalink
Manage impossible examples SQuAD v2
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Jan 21, 2020
1 parent 983c484 commit 073219b
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/transformers/data/processors/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
token_to_orig_map=span["token_to_orig_map"],
start_position=start_position,
end_position=end_position,
is_impossible=span_is_impossible
)
)
return features
Expand Down Expand Up @@ -332,6 +333,7 @@ def squad_convert_examples_to_features(
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float)

if not is_training:
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
Expand All @@ -349,6 +351,7 @@ def squad_convert_examples_to_features(
all_end_positions,
all_cls_index,
all_p_mask,
all_is_impossible
)

return features, dataset
Expand All @@ -369,14 +372,15 @@ def gen():
"end_position": ex.end_position,
"cls_index": ex.cls_index,
"p_mask": ex.p_mask,
"is_impossible": ex.is_impossible
},
)

return tf.data.Dataset.from_generator(
gen,
(
{"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
{"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
{"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32, "is_impossible": tf.int32},
),
(
{
Expand All @@ -389,6 +393,7 @@ def gen():
"end_position": tf.TensorShape([]),
"cls_index": tf.TensorShape([]),
"p_mask": tf.TensorShape([None]),
"is_impossible": tf.TensorShape([])
},
),
)
Expand Down Expand Up @@ -658,6 +663,7 @@ def __init__(
token_to_orig_map,
start_position,
end_position,
is_impossible
):
self.input_ids = input_ids
self.attention_mask = attention_mask
Expand All @@ -674,7 +680,7 @@ def __init__(

self.start_position = start_position
self.end_position = end_position

self.is_impossible = is_impossible

class SquadResult(object):
"""
Expand Down

0 comments on commit 073219b

Please sign in to comment.