Commit 6d2745f: return targets

elboy3 committed Dec 15, 2023
1 parent 5856764
Showing 3 changed files with 15 additions and 15 deletions.
dataquality/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@
"""


__version__ = "1.4.0"
__version__ = "1.4.1"

import sys
from typing import Any, List, Optional
dataquality/loggers/data_logger/seq2seq/formatters.py (16 changes: 8 additions & 8 deletions)

@@ -43,7 +43,7 @@ def format_text(
tokenizer: PreTrainedTokenizerFast,
max_tokens: Optional[int],
split_key: str,
-) -> Tuple[AlignedTokenData, List[List[str]], Optional[List[str]]]:
+) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
"""Tokenize and align the `text` samples
`format_text` tokenizes and computes token alignments for
@@ -83,10 +83,10 @@ def format_text(
Aligned token data for *just* target tokens, based on `text`
token_label_str: List[List[str]]
The target tokens (as strings) - see `Seq2SeqDataLogger.token_label_str`
-label_strs: Optional[List[str]]
+targets: List[str]
The decoded response tokens - i.e. the string representation of the
Targets for each sample. Note that this is only computed for
-Decoder-Only models.
+Decoder-Only models. Returns [] for Encoder-Decoder
"""
pass
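
For context, a minimal sketch of the new calling convention (the variable bindings here are illustrative, not part of this commit). The third element is now always a List[str]: empty for Encoder-Decoder models, populated for Decoder-Only models, so emptiness rather than a None sentinel signals that no targets were computed.

batch_aligned_data, token_label_str, targets = formatter.format_text(
    text=texts,
    ids=ids,
    tokenizer=tokenizer,
    max_tokens=max_tokens,
    split_key=split_key,
)
if len(targets) > 0:
    # Decoder-Only: decoded response strings were computed
    labels = targets
# Encoder-Decoder formatters return [], so no Optional narrowing is needed.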

@@ -221,7 +221,7 @@ def format_text(
tokenizer: PreTrainedTokenizerFast,
max_tokens: Optional[int],
split_key: str,
-) -> Tuple[AlignedTokenData, List[List[str]], Optional[List[str]]]:
+) -> Tuple[AlignedTokenData, List[List[str]], List[str]]:
"""Further validation for Encoder-Decoder
For Encoder-Decoder we need to:
@@ -266,7 +266,7 @@ def format_text(
id_to_tokens = dict(zip(ids, token_label_ids))
self.logger_config.id_to_tokens[split_key].update(id_to_tokens)

-return batch_aligned_data, token_label_str, None
+return batch_aligned_data, token_label_str, []

@torch.no_grad()
def generate_sample(
@@ -430,7 +430,7 @@ def format_text(
# Empty initialization
batch_aligned_data = AlignedTokenData([], [])
token_label_str = []
-str_labels = []
+targets = []

# Decode then re-tokenize just the response labels to get correct offsets
for token_label_ids in tqdm(
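
The loop opened here implements the decode-then-re-tokenize step named in the comment above. Per sample it amounts to something like the following sketch (hypothetical variable names, standard Hugging Face fast-tokenizer calls):

# Decode the response label ids back into a string...
response_str = tokenizer.decode(token_label_ids, skip_special_tokens=True)
# ...then re-tokenize just that string, so character offsets are relative
# to the response text alone rather than to the full prompt + response.
encoding = tokenizer(response_str, return_offsets_mapping=True)
response_offsets = encoding["offset_mapping"]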
@@ -456,7 +456,7 @@
max_input_tokens,
)
batch_aligned_data.append(response_aligned_data)
-str_labels.append(response_str)
+targets.append(response_str)

# Save the tokenized response labels for each sample
id_to_tokens = dict(zip(ids, tokenized_labels))
@@ -471,7 +471,7 @@
id_to_formatted_prompt_length
)

-return batch_aligned_data, token_label_str, str_labels
+return batch_aligned_data, token_label_str, targets

@torch.no_grad()
def generate_sample(
dataquality/loggers/data_logger/seq2seq/seq2seq_base.py (12 changes: 6 additions & 6 deletions)

@@ -118,14 +118,14 @@ def validate_and_format(self) -> None:
label_len = len(self.labels)
text_len = len(self.texts)
id_len = len(self.ids)
-if label_len > 0:
+if label_len > 0: # Encoder-Decoder case
assert id_len == text_len == label_len, (
"IDs, texts, and labels must be the same length, got "
f"({id_len} ids, {text_len} texts, {label_len} labels)"
)
-else:
+else: # Decoder-Only case
assert id_len == text_len, (
"IDs and textsmust be the same length, got "
"IDs and texts must be the same length, got "
f"({id_len} ids, {text_len} texts)"
)
assert self.logger_config.tokenizer, (
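
Spelled out as a standalone function (a sketch, not the actual method), the length validation distinguishes the two cases like this:

from typing import List

def validate_lengths(ids: List[int], texts: List[str], labels: List[str]) -> None:
    if len(labels) > 0:  # Encoder-Decoder: targets are supplied up front
        assert len(ids) == len(texts) == len(labels), (
            "IDs, texts, and labels must be the same length, got "
            f"({len(ids)} ids, {len(texts)} texts, {len(labels)} labels)"
        )
    else:  # Decoder-Only: targets are derived later from the responses
        assert len(ids) == len(texts), (
            "IDs and texts must be the same length, got "
            f"({len(ids)} ids, {len(texts)} texts)"
        )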
@@ -150,7 +150,7 @@ def validate_and_format(self) -> None:
(
batch_aligned_token_data,
token_label_str,
-labels,
+targets,
) = self.formatter.format_text(
text=texts,
ids=self.ids,
@@ -161,8 +161,8 @@ def validate_and_format(self) -> None:
self.token_label_offsets = batch_aligned_token_data.token_label_offsets
self.token_label_positions = batch_aligned_token_data.token_label_positions
self.token_label_str = token_label_str
-if labels is not None:
-self.labels = labels
+if len(targets) > 0: # For Decoder-Only we update the 'targets' here
+self.labels = targets

def _get_input_df(self) -> DataFrame:
df_dict = {
