-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathpipeline.py
29 lines (21 loc) · 987 Bytes
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from pathlib import Path
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import concept
import capture
def main(argv=None):
    """Run the crop -> concept inversion -> decomposition pipeline on one directory.

    The heavy lifting is delegated to the project-local ``concept`` and
    ``capture`` modules:

    1. ``concept.crop`` extracts square crops from the image for each of the
       binary masks located in ``<path>/masks``.
    2. For every entry of the directory returned by the crop step,
       ``concept.invert`` inverts the concept and ``concept.infer`` generates
       texture views from the result.
    3. ``capture.get_data`` builds a dataset from all generations,
       ``capture.get_inference_module`` loads the pretrained decomposition
       model, and a Lightning ``Trainer`` runs prediction with it.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    """
    parser = ArgumentParser(
        description='Extract mask crops, invert each region into a concept, '
                    'generate texture views and run the decomposition model.')
    parser.add_argument(
        'path', type=Path,
        help='input directory; binary masks are expected in <path>/masks')
    # Previously hard-coded; the default keeps the old behavior byte-for-byte.
    parser.add_argument(
        '--ckpt', default='model.ckpt',
        help='pretrained decomposition model checkpoint (default: %(default)s)')
    args = parser.parse_args(argv)

    # Every downstream step treats `path` as a directory, so fail early with a
    # proper CLI error instead of deep inside the pipeline.
    if not args.path.is_dir():
        parser.error(f'{args.path} is not an existing directory')

    ## Extract square crops from image for each of the binary masks located in <path>/masks
    # NOTE(review): crop() lives in `concept`; the result is used as a Path-like
    # directory below (via .iterdir()) — TODO confirm its exact contract.
    regions = concept.crop(args.path)

    ## Iterate through regions to invert the concept and generate texture views
    # iterdir() yields entries in arbitrary order; sort for reproducible runs.
    for region in sorted(regions.iterdir()):
        # presumably LoRA weights/path for this region, per the name — verify in concept.invert
        lora = concept.invert(region)
        concept.infer(lora, renorm=True)

    ## Construct a dataset with all generations and load pretrained decomposition model
    data = capture.get_data(predict_dir=args.path, predict_ds='sd')
    module = capture.get_inference_module(pt=args.ckpt)

    ## Proceed with inference on decomposition model (single GPU, 16-bit precision)
    decomp = Trainer(default_root_dir=args.path, accelerator='gpu', devices=1, precision=16)
    decomp.predict(module, data)


if __name__ == '__main__':
    main()