
feat: s2s auto chat support #779

Merged: 4 commits merged into main from feat/s2s-auto-chat-support on Oct 24, 2023
Conversation

@elboy3 (Contributor) commented on Oct 23, 2023

V1 of auto chat support. This just unravels the turns, logging the input as the user turn and the target as the assistant turn.

V2 will include some of the chat history

https://app.shortcut.com/galileo/story/8388/dq-support-basic-chat-models
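
For context, a minimal sketch of what the V1 unraveling could look like, assuming each sample stores its turns as a list of {"role", "content"} dicts (the field names here are illustrative, not the actual schema):

from typing import Dict, List


def unravel_turns(turns: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Pair each user turn with the assistant turn that answers it."""
    samples = []
    pending_input = None
    for turn in turns:
        if turn["role"] == "user":
            pending_input = turn["content"]
        elif turn["role"] == "assistant" and pending_input is not None:
            # Input becomes the user turn, target the assistant turn
            samples.append({"input": pending_input, "target": turn["content"]})
            pending_input = None
    return samples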

Ran e2e tests for:
✅ chat data
✅ auto with alpaca
✅ completion dataset

Comment on lines 115 to 129
# Add validation data if missing, add 'id' column
dd, dataset_config = self._validate_dataset_dict(dd, []), dataset_config
# Apply the datasets custom formatter on load dataset dict
col_names = (
    dd[Split.train].column_names
    if dataset_config.formatter.remove_columns
    else []
)
dd = dd.map(
    dataset_config.formatter.format_batch,
    batched=True,
    remove_columns=col_names,
    with_indices=True,
)
dd, dataset_config = self._validate_dataset_dict(dd, []), dataset_config
elboy3 (author):

will clean this up before we merge

Contributor:

Nice nice! I think this looks pretty clean to me though. Is the assumption that you could also be formatting e.g. with the Alpaca formatter?
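
(For reference, a hedged sketch of how a batched formatter plugs into datasets.map; the identity formatter and column names below are stand-ins, not the real Alpaca formatter:)

from typing import Dict, List

from datasets import Dataset


def format_batch(batch: Dict[str, List], idxs: List[int]) -> Dict[str, List]:
    # Stand-in formatter: a real one (e.g. Alpaca) would rewrite the
    # instruction/input columns into a single formatted prompt
    return batch


ds = Dataset.from_dict({"text": ["hello", "world"]})
ds = ds.map(
    format_batch,
    batched=True,
    with_indices=True,  # passes idxs so the formatter can track sample ids
    remove_columns=ds.column_names,  # mirrors formatter.remove_columns above
)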

Comment on lines 48 to 56
metadata = sample.get(self.metadata_col, {})
sample_cols = [
    col
    for col in sample.keys()
    if col not in [self.metadata_col, self.turns_col]
]
for col in sample_cols:
    metadata[col] = sample[col]
unraveled_turns = unraveled_turns | metadata
elboy3 (author):

will also clean this up
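
(Roughly, the intent here as a cleaned-up, Python 3.8-compatible sketch; the column-name defaults are assumptions, the rest is taken from the snippet above:)

from typing import Any, Dict


def merge_sample_metadata(
    sample: Dict[str, Any],
    unraveled_turns: Dict[str, Any],
    metadata_col: str = "metadata",  # assumed default
    turns_col: str = "turns",  # assumed default
) -> Dict[str, Any]:
    # Every column that isn't the turns or metadata column is
    # treated as sample-level metadata
    metadata = dict(sample.get(metadata_col, {}))
    for col, value in sample.items():
        if col not in (metadata_col, turns_col):
            metadata[col] = value
    # 3.8-safe equivalent of `unraveled_turns | metadata`
    merged = dict(unraveled_turns)
    merged.update(metadata)
    return merged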

@codecov-commenter commented on Oct 23, 2023

Codecov Report

Merging #779 (3b36f93) into main (6c77a8c) will decrease coverage by 0.32%.
The diff coverage is 34.28%.

@@            Coverage Diff             @@
##             main     #779      +/-   ##
==========================================
- Coverage   87.72%   87.41%   -0.32%     
==========================================
  Files         184      186       +2     
  Lines       15127    15195      +68     
==========================================
+ Hits        13270    13282      +12     
- Misses       1857     1913      +56     
Files Coverage Δ
dataquality/dq_auto/schema.py 93.10% <100.00%> (ø)
...ataquality/loggers/data_logger/base_data_logger.py 88.34% <100.00%> (ø)
...aquality/integrations/seq2seq/formatters/alpaca.py 76.92% <76.92%> (ø)
dataquality/integrations/seq2seq/auto.py 0.00% <0.00%> (ø)
...ataquality/integrations/seq2seq/formatters/base.py 64.86% <64.86%> (ø)
...ataquality/integrations/seq2seq/formatters/chat.py 0.00% <0.00%> (ø)

... and 3 files with indirect coverage changes


@@ -0,0 +1,19 @@
from typing import Dict, Type
Contributor:

I never quite understand when to put things in the init?

elboy3 (author):

things that are common to all the formatters, like this map!
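
(Presumably a name-to-class registry; a sketch assuming these class names, which are illustrative — only the module paths appear in the coverage report above:)

from typing import Dict, Type

from dataquality.integrations.seq2seq.formatters.alpaca import AlpacaFormatter
from dataquality.integrations.seq2seq.formatters.base import BaseFormatter
from dataquality.integrations.seq2seq.formatters.chat import ChatFormatter

# Map a dataset/config name to the formatter class that handles it
FORMATTER_MAP: Dict[str, Type[BaseFormatter]] = {
    "alpaca": AlpacaFormatter,
    "chat": ChatFormatter,
}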

target_col: str = "output"
max_train_size: int = 1000

def format_sample(self, sample: Dict[str, str], idx: int) -> Dict[str, str]:
Contributor:

What is idx used for?

Contributor:

Should this be made optional?

max_train_size: Optional[int] = None
remove_columns: bool = False

def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]:
Contributor:

This approach doesn't seem bad! There may be a bit of extra data copying, but for a working solution this seems fine.

My only thought is you could have a wrapper data class something like:

from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class DataBatch:
    batch: Dict[str, List]
    active_row: int = 0

    def get(self, key: str) -> Any:
        # Pull the active row's value out of the column list
        return self.batch[key][self.active_row]

This could at least help avoid some of the copying involved in creating the sample. You would still need to return a Dict and add to the result I guess.
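
Usage of the wrapper might look like:

batch = {"input": ["hi", "bye"], "target": ["hello", "goodbye"]}
db = DataBatch(batch)
db.active_row = 1
assert db.get("input") == "bye"  # row access without copying the batch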

@elboy3 marked this pull request as ready for review on October 23, 2023, 22:25
@elboy3 requested review from dcaustin33 and a team as code owners on October 23, 2023, 22:25
@jonathangomesselman (Contributor) left a comment:

Awesome 🚢 !! I left just some small comments. On to the chat with history now :)

)
# We must re-add the id column if it's been dropped
dd = self._validate_dataset_dict(dd, [])
return dd, dataset_config
Contributor:

Beautiful!

@dataclass
class BatchData:
Contributor:

Nice 👍

# Add sample level metadata
turn_data.update(metadata)
for k, v in turn_data.items():
    # NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta'
Contributor:

This comment is maybe meant to go above?

@elboy3 merged commit e40e48f into main on Oct 24, 2023
@elboy3 deleted the feat/s2s-auto-chat-support branch on October 24, 2023, 02:53
@elboy3 mentioned this pull request on Oct 24, 2023
@setu4993 (Member) left a comment:

Left a few questions / notes. Feel free to take it or leave it!

max_train_sz = (
    dataset_config.max_train_size or dataset_config.formatter.max_train_size
)
max_train_sz = max_train_size or dataset_config.formatter.max_train_size
setu4993 (Member):

Seeing that max_train_size is optional even for formatters, isn't it possible we'd end up with None here?
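
(If both are Optional, the fallback chain would need a final default; DEFAULT_MAX_TRAIN_SIZE is hypothetical, mirroring the formatter's old default of 1000:)

DEFAULT_MAX_TRAIN_SIZE = 1000  # hypothetical final fallback

max_train_sz = (
    dataset_config.max_train_size
    or dataset_config.formatter.max_train_size
    or DEFAULT_MAX_TRAIN_SIZE
)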

Comment on lines +70 to +79
# Add metadata to each turn
turn_meta = {
    f"{role}_{col}": turn[col]
    for col in turn.keys()
    if col not in turn_default_cols
    and isinstance(turn[col], valid_meta_types)
}
# Add turn level metadata to turn
# NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta'
turn_data.update(turn_meta)
setu4993 (Member):

Not seeing this in the docstring example above?

Comment on lines +92 to +93
# Reset turn data
turn_data = {}
setu4993 (Member):

Why only under the elif?

# Add validation data if missing, add 'id' column
dd = self._validate_dataset_dict(dd, [])
formatter = dataset_config.formatter
if formatter.process_batch:
setu4993 (Member):

When would this be false?

    # We must re-add the id column if it's been dropped
    dd = self._validate_dataset_dict(dd, [])
else:
    dd = dd.map(formatter.format_sample, remove_columns=formatter.remove_cols)
setu4993 (Member):

It'd be nice to just have a wrapper method format_batch on all formatters that operates at the batch level, so we don't have to if-else it this way.
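
(One possible shape for that, sketched against the signatures quoted above rather than the merged implementation:)

from abc import ABC, abstractmethod
from typing import Dict, List


class BaseFormatter(ABC):
    @abstractmethod
    def format_sample(self, sample: Dict[str, str], idx: int) -> Dict[str, str]:
        """Format a single sample; each formatter overrides this."""

    def format_batch(self, batch: Dict[str, List], idxs: List[int]) -> Dict[str, List]:
        # Default batch wrapper: unbatch, format each sample, re-batch,
        # so every formatter can go through `dd.map(..., batched=True)`
        result: Dict[str, List] = {}
        keys = list(batch.keys())
        for i, idx in enumerate(idxs):
            sample = {k: batch[k][i] for k in keys}
            for k, v in self.format_sample(sample, idx).items():
                result.setdefault(k, []).append(v)
        return result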
