feat: s2s auto chat support #779
Conversation
# Add validation data if missing, add 'id' column
dd, dataset_config = self._validate_dataset_dict(dd, []), dataset_config
# Apply the datasets custom formatter on load dataset dict
col_names = (
    dd[Split.train].column_names
    if dataset_config.formatter.remove_columns
    else []
)
dd = dd.map(
    dataset_config.formatter.format_batch,
    batched=True,
    remove_columns=col_names,
    with_indices=True,
)
dd, dataset_config = self._validate_dataset_dict(dd, []), dataset_config
will clean this up before we merge
Nice nice! I think this looks pretty clean to me though. Is the assumption that you could also be formatting e.g. with the Alpaca formatter?
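For context, a minimal sketch of a batched formatter that would plug into the dd.map call quoted above (batched=True, with_indices=True); the AlpacaFormatter name and the instruction/input/output column names are illustrative assumptions, not the actual implementation:

from typing import Dict, List

class AlpacaFormatter:
    remove_columns = True

    def format_batch(self, batch: Dict[str, List], idxs: List[int]) -> Dict[str, List]:
        # Collapse instruction + optional input into a single prompt per row
        prompts = [
            f"{inst}\n\n{inp}" if inp else inst
            for inst, inp in zip(batch["instruction"], batch["input"])
        ]
        return {"id": idxs, "text": prompts, "output": batch["output"]}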
metadata = sample.get(self.metadata_col, {})
sample_cols = [
    col
    for col in sample.keys()
    if col not in [self.metadata_col, self.turns_col]
]
for col in sample_cols:
    metadata[col] = sample[col]
unraveled_turns = unraveled_turns | metadata
will also clean this up
Codecov Report
@@            Coverage Diff             @@
##             main     #779      +/-   ##
==========================================
- Coverage   87.72%   87.41%   -0.32%
==========================================
  Files         184      186       +2
  Lines       15127    15195      +68
==========================================
+ Hits        13270    13282      +12
- Misses       1857     1913      +56

... and 3 files with indirect coverage changes
@@ -0,0 +1,19 @@
from typing import Dict, Type
I never quite understand when to put things in the __init__?
things that are common to all the formatters, like this map!
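For reference, a minimal sketch of the kind of map this refers to, e.g. in a formatters/__init__.py; the BaseFormatter, AlpacaFormatter, and ChatFormatter names and module paths are assumptions for illustration only:

from typing import Dict, Type

from .base import BaseFormatter      # hypothetical module layout
from .alpaca import AlpacaFormatter  # hypothetical
from .chat import ChatFormatter      # hypothetical

# Registry shared by all formatters, keyed by config name
FORMATTER_MAP: Dict[str, Type[BaseFormatter]] = {
    "alpaca": AlpacaFormatter,
    "chat": ChatFormatter,
}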
target_col: str = "output"
max_train_size: int = 1000

def format_sample(self, sample: Dict[str, str], idx: int) -> Dict[str, str]:
What is idx used for?
Should this be made optional?
max_train_size: Optional[int] = None
remove_columns: bool = False

def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]:
This approach doesn't seem bad! There may be a bit of extra data copying, but for a working solution this seems fine.
My only thought is you could have a wrapper data class something like:
@dataclass
class DataBatch:
    batch: Dict
    active_row: int = 0
    ...

    def get(self, key: str):
        return self.batch[key][self.active_row]
    ...
This could at least help avoid some of the copying involved in creating the sample. You would still need to return a Dict and add to the result I guess.
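A rough usage sketch of that wrapper inside a batched formatter, assuming the DataBatch dataclass above, the typing imports shown earlier, and illustrative "input"/"target" column names:

def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]:
    result: Dict[str, List] = {"text": [], "output": []}
    data = DataBatch(batch=batch)
    for i, _ in enumerate(idxs):
        data.active_row = i
        # Read columns for the current row without copying the whole sample
        result["text"].append(data.get("input"))
        result["output"].append(data.get("target"))
    return result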
Awesome 🚢 !! I left just some small comments. On to the chat with history now :)
)
# We must re-add the id column if it's been dropped
dd = self._validate_dataset_dict(dd, [])
return dd, dataset_config
Beautiful!
@dataclass
class BatchData:
Nice 👍
# Add sample level metadata
turn_data.update(metadata)
for k, v in turn_data.items():
    # NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta'
This comment maybe is meant to be above
Left a few questions / notes. Feel free to take it or leave it!
max_train_sz = (
    dataset_config.max_train_size or dataset_config.formatter.max_train_size
)
max_train_sz = max_train_size or dataset_config.formatter.max_train_size
Seeing that max_train_size is optional even for formatters, isn't it possible we'd end up with None here?
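If both can be None, one possible guard would be a final fallback along these lines (DEFAULT_MAX_TRAIN_SIZE is a made-up name, purely to illustrate):

DEFAULT_MAX_TRAIN_SIZE = 1000  # hypothetical fallback

max_train_sz = (
    max_train_size
    or dataset_config.formatter.max_train_size
    or DEFAULT_MAX_TRAIN_SIZE
)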
# Add metadata to each turn
turn_meta = {
    f"{role}_{col}": turn[col]
    for col in turn.keys()
    if col not in turn_default_cols
    and isinstance(turn[col], valid_meta_types)
}
# Add turn level metadata to turn
# NOTE: When we drop p3.8 we can use 'turn_data |= turn_meta'
turn_data.update(turn_meta)
Not seeing this in the docstring example above?
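For anyone following along, a small illustration of what the turn-level metadata extraction above produces, assuming a hypothetical turn dict (field names are illustrative only):

turn = {"role": "user", "content": "hi", "rating": 5, "tags": ["a"]}
turn_default_cols = {"role", "content"}
valid_meta_types = (str, int, float, bool)

role = turn["role"]
turn_meta = {
    f"{role}_{col}": turn[col]
    for col in turn.keys()
    if col not in turn_default_cols and isinstance(turn[col], valid_meta_types)
}
# turn_meta == {"user_rating": 5}; "tags" is skipped since a list is not a valid meta type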
# Reset turn data
turn_data = {}
Why only under the elif?
# Add validation data if missing, add 'id' column
dd = self._validate_dataset_dict(dd, [])
formatter = dataset_config.formatter
if formatter.process_batch:
When would this be false?
    # We must re-add the id column if it's been dropped
    dd = self._validate_dataset_dict(dd, [])
else:
    dd = dd.map(formatter.format_sample, remove_columns=formatter.remove_cols)
It'd be nice to just have a wrapper method format_batch for all formatters that operates on the batch level, so we don't have to if-else it this way.
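Something like the following base-class default is the idea; this is only a sketch, assuming a BaseFormatter class and the typing imports shown earlier, not the actual code:

class BaseFormatter:
    def format_sample(self, sample: Dict[str, str], idx: int) -> Dict[str, str]:
        raise NotImplementedError

    def format_batch(self, batch: Dict, idxs: List[int]) -> Dict[str, List]:
        # Default batch wrapper: apply format_sample row by row, so callers can
        # always map at the batch level regardless of the concrete formatter
        result: Dict[str, List] = {}
        for i, idx in enumerate(idxs):
            sample = {col: vals[i] for col, vals in batch.items()}
            for col, val in self.format_sample(sample, idx).items():
                result.setdefault(col, []).append(val)
        return result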
V1 of auto chat support. This just unravels the turns, with the input as the user message and the target as the assistant message.
V2 will include some of the chat history.
https://app.shortcut.com/galileo/story/8388/dq-support-basic-chat-models
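As a rough illustration of the V1 unraveling (the column names and turn structure here are assumptions, not the exact schema):

# One chat sample with a list of turns ...
sample = {
    "turns": [
        {"role": "user", "content": "What's 2+2?"},
        {"role": "assistant", "content": "4"},
    ]
}

# ... becomes one completion-style row per user/assistant exchange:
unraveled = [
    {"input": "What's 2+2?", "target": "4"},
]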
Did e2e tests for:
✅ chat data (link)
✅ auto with alpaca (link)
✅ completion dataset