Tracking: development around Rectified Flow #182
The first stage of refactoring and migration to continuous acceleration has been completed. Rectified Flow models remain fully compatible, but the following configuration keys no longer take effect for Rectified Flow at training time (at inference time they are converted automatically if the config file does not contain the new keys):
The inference API (scripts/infer.py) has changed as follows:
ONNX export is now supported, but exporting some early Rectified Flow models raises a KeyError. If this happens, manually add the missing keys to the configuration file.
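For readers new to the method, the core idea that separates Rectified Flow from DDPM's discrete noise schedule can be sketched in a few lines: the model regresses a velocity along the straight line between a noise sample `x0` (t = 0) and a data sample `x1` (t = 1), and sampling integrates the ODE dx/dt = v(x, t) over continuous t in [0, 1]. The sketch below is a toy, scalar illustration under that convention only; the function names (`interpolate`, `velocity_target`, `rf_loss`, `euler_sample`) are hypothetical and are not part of this repository, whose time convention and network signature may differ.

```python
import random
from typing import Callable, List, Tuple

# Hypothetical toy sketch of the Rectified Flow objective and sampler.
# Convention assumed here: t = 0 is noise (x0), t = 1 is data (x1).


def interpolate(x0: float, x1: float, t: float) -> float:
    """Point at time t on the straight line from x0 (t=0) to x1 (t=1)."""
    return (1.0 - t) * x0 + t * x1


def velocity_target(x0: float, x1: float) -> float:
    """Regression target for the velocity model: the constant slope x1 - x0."""
    return x1 - x0


def rf_loss(model: Callable[[float, float], float],
            pairs: List[Tuple[float, float]],
            rng: random.Random) -> float:
    """Mean squared error between model(x_t, t) and x1 - x0 at random times t."""
    total = 0.0
    for x0, x1 in pairs:
        t = rng.random()  # t ~ Uniform[0, 1)
        x_t = interpolate(x0, x1, t)
        total += (model(x_t, t) - velocity_target(x0, x1)) ** 2
    return total / len(pairs)


def euler_sample(velocity: Callable[[float, float], float],
                 x0: float, steps: int) -> float:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with `steps` Euler steps."""
    x, dt = x0, 1.0 / steps
    for i in range(steps):
        x += dt * velocity(x, i * dt)
    return x


if __name__ == "__main__":
    pair = (-1.0, 2.0)
    # An "oracle" velocity that matches the single training path exactly.
    oracle = lambda x, t: velocity_target(*pair)
    print(rf_loss(oracle, [pair], random.Random(0)))  # 0.0
    print(euler_sample(oracle, pair[0], steps=1))     # 2.0
    print(euler_sample(oracle, pair[0], steps=4))     # 2.0
```

In this toy there is only one straight path, so a single Euler step already lands on `x1`. With a trained network over real data the ODE trajectories are generally curved, so because t is continuous the step count becomes a free inference-time choice that trades accuracy for speed.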
The second stage of refactoring was completed in dc6896b. Because the state dict layout has changed, models trained on this branch before that commit should be migrated with the following script:

```python
import collections
import pathlib
from typing import Dict, Any

import click
import torch


@click.command()
@click.argument(
    'in_ckpt', type=click.Path(
        exists=True, dir_okay=False, file_okay=True, readable=True, path_type=pathlib.Path
    )
)
@click.argument(
    'out_ckpt', type=click.Path(
        exists=False, dir_okay=False, file_okay=True, writable=True, path_type=pathlib.Path
    )
)
def migrate_reflow(in_ckpt: pathlib.Path, out_ckpt: pathlib.Path):
    # Load the old checkpoint on CPU so no GPU is needed for migration.
    ckpt = torch.load(in_ckpt, map_location='cpu')
    in_state_dict: Dict[str, Any] = ckpt['state_dict']
    out_state_dict = collections.OrderedDict()
    for k, v in in_state_dict.items():
        if 'denoise_fn' in k:
            # Rename denoise_fn.* parameters to velocity_fn.*
            out_state_dict[k.replace('denoise_fn', 'velocity_fn')] = v
        elif 'spec_min' in k or 'spec_max' in k:
            # Drop spec_min / spec_max entries from the migrated state dict.
            continue
        else:
            # Copy every other entry unchanged.
            out_state_dict[k] = v
    # Keep the original category and write the migrated state dict.
    torch.save({'category': ckpt['category'], 'state_dict': out_state_dict}, out_ckpt)


if __name__ == '__main__':
    migrate_reflow()
```

The following configuration keys are renamed:
We are introducing Rectified Flow, a new ODE-based generative model, to this repository (in the RectifiedFlow branch). The differences between Rectified Flow and the currently used DDPM will result in some API changes. Testing and adaptation may take a week or more. Since we are still at an early stage and the code is not yet well organized, the APIs and configurations on that branch may change at any time without backward compatibility. This issue is raised mainly to inform those who are testing and researching on that branch of the changes (and possible migration steps).

TODOs