-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsgd.py
48 lines (40 loc) · 1.33 KB
/
sgd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import datetime
from pathlib import Path
import numpy as np
from jsonargparse import CLI
from lattice_quantizer.criteria.nsm import nsm_cpu
from lattice_quantizer.lr_scheduler import CosineLR, FactorizedLR, RatioLR # noqa: F401
from lattice_quantizer.optimizer import SGDLatticeQuantizerOptimizer
def main(
    n: int,
    optimizer: SGDLatticeQuantizerOptimizer,
    output_dir: Path = Path("results"),
    checknsm_num_samples: int = int(1e6),
    checknsm_parallel: bool = False,
):
    """Optimize an n-dimensional lattice basis, save it, and report its NSM.

    Each run writes into a fresh timestamped subdirectory of ``output_dir``:
    ``basis.txt`` and ``basis.npy`` hold the array returned by
    ``optimizer.optimize(n)``, ``nsm.txt`` holds the NSM estimate computed by
    ``nsm_cpu``, and ``logs/`` is assigned as the optimizer's ``log_dir``.

    Args:
        n: Dimension passed to ``optimizer.optimize``.
        optimizer: Configured SGD optimizer. Its ``steps`` and ``batch_size``
            are embedded in the run directory name, and its ``log_dir`` is
            overwritten to point inside that directory.
        output_dir: Parent directory under which run subdirectories are made.
        checknsm_num_samples: Sample count handed to ``nsm_cpu`` for the final
            NSM estimate.
        checknsm_parallel: If true, pass 65536 instead of 0 as the fourth
            argument of ``nsm_cpu``.
    """
    # Dashes instead of colons: ':' is a reserved character in Windows file
    # names and is treated as a host separator by tools such as scp/rsync.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_subdir = (
        output_dir
        / f"quantizer_n{n}_T{optimizer.steps}_B{optimizer.batch_size}_{timestamp}"
    )
    output_subdir.mkdir(parents=True, exist_ok=True)
    optimizer.log_dir = output_subdir / "logs"

    result = optimizer.optimize(n)

    # Persist the basis before the NSM check so that a failure or interruption
    # during the (potentially long) sampling does not lose the optimization.
    np.savetxt(output_subdir / "basis.txt", result)
    np.save(output_subdir / "basis.npy", result)

    nsm, var = nsm_cpu(
        result,
        checknsm_num_samples,
        np.random.default_rng(),
        65536 if checknsm_parallel else 0,
    )

    # NOTE(review): `var` is rendered after "+/-"; confirm whether nsm_cpu
    # returns a variance or an uncertainty (e.g. a standard error).
    nsm_str = (
        f"{np.format_float_positional(nsm)} +/- {np.format_float_scientific(var)}"
    )
    (output_subdir / "nsm.txt").write_text(nsm_str + "\n", encoding="utf-8")
    # nsm_str carries no trailing newline, so print emits exactly one.
    print(f"NSM: {nsm_str}")  # noqa: T201
if __name__ == "__main__":
    # Build a command-line parser from main's signature and dispatch to it.
    CLI(components=main)