-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path preprocess_partnet.py
71 lines (51 loc) · 2.17 KB
/
preprocess_partnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import pathlib
from collections import defaultdict
from torch_geometric.transforms import Compose, NormalizeScale
from tqdm.auto import tqdm
import log_set
from training.data.partnet import HdfIO
from transforms import FPS, PosToFloat32
def _read_file_list(list_path, base_dir):
    """Read one split list file and resolve each entry against *base_dir*.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped.

    Args:
        list_path (pathlib.Path): Text file with one path per line.
        base_dir (pathlib.Path): Directory the listed paths are joined onto.

    Returns:
        list[str]: ``str(base_dir / entry)`` for every non-blank entry, in
        file order.
    """
    with list_path.open("r", encoding="utf-8") as handle:
        entries = (line.strip() for line in handle)
        return [str(base_dir / entry) for entry in entries if entry]


def main(args):
    """Group PartNet split lists by split and merge each group into one HDF5 file.

    Every sub-directory of ``args.data_dir`` is scanned for files whose stem is
    ``train_files``, ``val_files`` or ``test_files`` (any extension). The paths
    listed in them are resolved relative to that sub-directory and grouped by
    the stem prefix (``train`` / ``val`` / ``test``). For each group,
    ``HdfIO.union_and_compress`` is called with the grouped paths, the output
    path ``<out_dir>/<prefix>.h5`` and a NormalizeScale -> FPS -> PosToFloat32
    transform chain.

    Args:
        args: Parsed CLI namespace providing ``data_dir``, ``out_dir``,
            ``num_points`` and ``device``.

    Raises:
        ValueError: If ``args.data_dir`` is not an existing directory.
    """
    data_dir = pathlib.Path(args.data_dir)
    if not data_dir.is_dir():
        raise ValueError(f"'{data_dir}' is not a directory")
    out_dir = pathlib.Path(args.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    split_stems = ("test_files", "train_files", "val_files")
    # iterdir() yields entries in arbitrary, filesystem-dependent order; sort so
    # the order of inputs inside each merged file is reproducible across runs.
    progress = tqdm(sorted(data_dir.iterdir()))
    merge_files = defaultdict(list)
    for sample_dir in progress:
        if not sample_dir.is_dir():
            continue
        progress.set_description(f"Scan: {sample_dir}")
        for file in sorted(sample_dir.iterdir()):
            # is_file() skips a directory that merely happens to share a split-list name.
            if file.stem not in split_stems or not file.is_file():
                continue
            prefix = file.stem.split("_")[0]  # "train_files" -> "train"
            merge_files[prefix].extend(_read_file_list(file, sample_dir))

    transforms = Compose(
        [
            NormalizeScale(),
            FPS(args.num_points, device=args.device),
            PosToFloat32(),
        ]
    )
    for prefix, files in merge_files.items():
        out_path = out_dir / f"{prefix}.h5"
        HdfIO.union_and_compress(files, str(out_path), transforms=transforms)
if __name__ == "__main__":
    # Command-line entry point: parse options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True, help="A path to PartNet data root")
    parser.add_argument("--out_dir", type=str, required=True,
                        help="A path to output dir with union files")
    parser.add_argument("--num_points", type=int, default=4096, required=False,
                        help="A number of points to sample for FPS")
    # The default used to be the typo "cpy", which is not a device name.
    # "cpu" is intended; the help text names "cuda" as the faster alternative.
    parser.add_argument("--device", default="cpu", type=str, required=False,
                        help="A device to process. cuda much faster")
    args = parser.parse_args()
    # Reject nonsensical sample counts up front instead of failing deep inside FPS.
    if args.num_points <= 0:
        parser.error("--num_points must be a positive integer")
    main(args)