Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vb/issue235 -- Convert config.pkl to config.yaml #240

Merged
merged 5 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ Image poses may be *locally* refined using the `--do-pose-sgd` flag. Please cons

## 6. Analysis of results

Once the model has finished training, the output directory will contain a configuration file `config.pkl`, neural network weights `weights.pkl`, image poses (if performing pose sgd) `pose.pkl`, and the latent embeddings for each image `z.pkl`. The latent embeddings are provided in the same order as the input particles. To analyze these results, use the `cryodrgn analyze` command to visualize the latent space and generate structures. `cryodrgn analyze` will also provide a template jupyter notebook for further interactive visualization and analysis.
Once the model has finished training, the output directory will contain a configuration file `config.yaml`, neural network weights `weights.pkl`, image poses (if performing pose sgd) `pose.pkl`, and the latent embeddings for each image `z.pkl`. The latent embeddings are provided in the same order as the input particles. To analyze these results, use the `cryodrgn analyze` command to visualize the latent space and generate structures. `cryodrgn analyze` will also provide a template jupyter notebook for further interactive visualization and analysis.

NEW in version 1.0: There are two additional tools `cryodrgn analyze_landscape` and `cryodrgn analyze_landscape_full` for more comprehensive and auomated analyses of cryodrgn results.

Expand Down Expand Up @@ -497,12 +497,11 @@ Additional structures may be generated using `cryodrgn eval_vol`:
weights Model weights

optional arguments:
-h, --help show this help message and exit
-c PKL, --config PKL CryoDRGN config.pkl file
-o O Output .mrc or directory
--prefix PREFIX Prefix when writing out multiple .mrc files (default:
vol_)
-v, --verbose Increaes verbosity
-h, --help show this help message and exit
-c YAML, --config YAML CryoDRGN config.yaml file
-o O Output .mrc or directory
--prefix PREFIX Prefix when writing out multiple .mrc files (default: vol_)
-v, --verbose Increase verbosity

Specify z values:
-z [Z [Z ...]] Specify one z-value
Expand All @@ -519,7 +518,7 @@ Additional structures may be generated using `cryodrgn eval_vol`:
-d DOWNSAMPLE, --downsample DOWNSAMPLE
Downsample volumes to this box size (pixels)

Overwrite architecture hyperparameters in config.pkl:
Overwrite architecture hyperparameters in config.yaml:
--norm NORM NORM
-D D Box size
--enc-layers QLAYERS Number of hidden layers
Expand All @@ -543,19 +542,19 @@ Additional structures may be generated using `cryodrgn eval_vol`:

To generate a volume at a single value of the latent variable:

$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.pkl -z ZVALUE -o reconstruct.mrc
$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.yaml -z ZVALUE -o reconstruct.mrc

The number of inputs for `-z` must match the dimension of your latent variable.

Or to generate a trajectory of structures from a defined start and ending point, use the `--z-start` and `--z-end` arugments:

$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.pkl --z-start -3 --z-end 3 -n 20 -o [WORKDIR]/trajectory
$ cryodrgn eval_vol [YOUR_WORKDIR]/weights.pkl --config [YOUR_WORKDIR]/config.yaml --z-start -3 --z-end 3 -n 20 -o [WORKDIR]/trajectory

This example generates 20 structures at evenly spaced values between z=[-3,3], assuming a 1-dimensional latent variable model.

Finally, a series of structures can be generated using values of z given in a file specified by the arugment `--zfile`:

$ cryodrgn eval_vol [WORKDIR]/weights.pkl --config [WORKDIR]/config.pkl --zfile zvalues.txt -o [WORKDIR]/trajectory
$ cryodrgn eval_vol [WORKDIR]/weights.pkl --config [WORKDIR]/config.yaml --zfile zvalues.txt -o [WORKDIR]/trajectory

The input to `--zfile` is expected to be an array of dimension (N_volumes x zdim), loaded with np.loadtxt.

Expand Down
2 changes: 1 addition & 1 deletion cryodrgn/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def gen_volumes(
"""Call cryodrgn eval_vol to generate volumes at specified z values
Input:
weights (str): Path to model weights .pkl
config (str): Path to config.pkl
config (str): Path to config.yaml
zfile (str): Path to .txt file of z values
outdir (str): Path to output directory for volumes,
device (int or None): Specify cuda device
Expand Down
11 changes: 4 additions & 7 deletions cryodrgn/commands/abinit_het.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from torch.nn.parallel import DataParallel
from torch.utils.data import DataLoader
from typing import Union
import cryodrgn
from cryodrgn import ctf, dataset, lie_tools, utils
from cryodrgn.beta_schedule import LinearSchedule, get_beta_schedule
import cryodrgn.config
from cryodrgn.lattice import Lattice
from cryodrgn.losses import EquivarianceLoss
from cryodrgn.models import HetOnlyVAE, unparallelize
Expand Down Expand Up @@ -649,11 +649,8 @@ def save_config(args, dataset, lattice, model, out_config):
config = dict(
dataset_args=dataset_args, lattice_args=lattice_args, model_args=model_args
)
config["seed"] = args.seed
with open(out_config, "wb") as f:
pickle.dump(config, f)
meta = dict(time=dt.now(), cmd=sys.argv, version=cryodrgn.__version__)
pickle.dump(meta, f)

cryodrgn.config.save(config, out_config)


def sort_poses(poses):
Expand Down Expand Up @@ -890,7 +887,7 @@ def main(args):
pose_model = model

# save configuration
out_config = "{}/config.pkl".format(args.outdir)
out_config = "{}/config.yaml".format(args.outdir)
save_config(args, data, lattice, model, out_config)

ps = PoseSearch(
Expand Down
7 changes: 6 additions & 1 deletion cryodrgn/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import os
import os.path
import shutil
from datetime import datetime as dt
import logging
Expand Down Expand Up @@ -348,7 +349,11 @@ def main(args):
workdir = args.workdir
zfile = f"{workdir}/z.{E}.pkl"
weights = f"{workdir}/weights.{E}.pkl"
config = f"{workdir}/config.pkl"
config = (
f"{workdir}/config.yaml"
if os.path.exists(f"{workdir}/config.yaml")
else f"{workdir}/config.pkl"
)
outdir = f"{workdir}/analyze.{E}"
if E == -1:
zfile = f"{workdir}/z.pkl"
Expand Down
10 changes: 8 additions & 2 deletions cryodrgn/commands/analyze_landscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import os
import os.path
from collections import Counter
from datetime import datetime as dt
import logging
Expand All @@ -16,6 +17,7 @@
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from cryodrgn import analysis, mrc, utils
import cryodrgn.config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -450,7 +452,11 @@ def main(args):
workdir = args.workdir
zfile = f"{workdir}/z.{E}.pkl"
weights = f"{workdir}/weights.{E}.pkl"
config = f"{workdir}/config.pkl"
config = (
f"{workdir}/config.yaml"
if os.path.exists(f"{workdir}/config.yaml")
else f"{workdir}/config.pkl"
)
outdir = f"{workdir}/landscape.{E}"

if args.outdir:
Expand Down Expand Up @@ -505,7 +511,7 @@ def main(args):

logger.info("Analyzing volumes...")
# get particle indices if the dataset was originally filtered
c = utils.load_pkl(config)
c = cryodrgn.config.load(config)
particle_ind = (
utils.load_pkl(c["dataset_args"]["ind"])
if c["dataset_args"]["ind"] is not None
Expand Down
13 changes: 9 additions & 4 deletions cryodrgn/commands/analyze_landscape_full.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
import os.path
import pprint
import shutil
from datetime import datetime as dt
Expand Down Expand Up @@ -159,7 +160,7 @@ def __getitem__(self, idx):


def generate_and_map_volumes(
zfile, cfg_pkl, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args
zfile, cfg, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args
):
# Sample z
logger.info(f"Sampling {args.N} particles from {zfile}")
Expand All @@ -175,7 +176,7 @@ def generate_and_map_volumes(
assert torch.cuda.is_available(), "No GPUs detected"
torch.set_default_tensor_type(torch.cuda.FloatTensor) # type: ignore

cfg = config.update_config_v1(cfg_pkl)
cfg = config.update_config_v1(cfg)
logger.info("Loaded configuration:")
pprint.pprint(cfg)

Expand Down Expand Up @@ -288,7 +289,11 @@ def main(args):
workdir = args.workdir
zfile = f"{workdir}/z.{E}.pkl"
weights = f"{workdir}/weights.{E}.pkl"
cfg_pkl = f"{workdir}/config.pkl"
cfg = (
f"{workdir}/config.yaml"
if os.path.exists(f"{workdir}/config.yaml")
else f"{workdir}/config.pkl"
)
landscape_dir = (
f"{workdir}/landscape.{E}" if args.landscape_dir is None else args.landscape_dir
)
Expand Down Expand Up @@ -320,7 +325,7 @@ def main(args):
z = utils.load_pkl(z_sampled_pkl)
else:
z, embeddings = generate_and_map_volumes(
zfile, cfg_pkl, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args
zfile, cfg, weights, mask_mrc, pca_obj_pkl, landscape_dir, outdir, args
)
utils.save_pkl(embeddings, embeddings_pkl)

Expand Down
2 changes: 1 addition & 1 deletion cryodrgn/commands/eval_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def add_args(parser):
)

group = parser.add_argument_group(
"Overwrite architecture hyperparameters in config.pkl"
"Overwrite architecture hyperparameters in config.yaml"
)
group.add_argument("--zdim", type=int, help="Dimension of latent variable")
group.add_argument(
Expand Down
8 changes: 6 additions & 2 deletions cryodrgn/commands/eval_vol.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
def add_args(parser):
parser.add_argument("weights", help="Model weights")
parser.add_argument(
"-c", "--config", metavar="PKL", required=True, help="CryoDRGN config.pkl file"
"-c",
"--config",
metavar="YAML",
required=True,
help="CryoDRGN config.yaml file",
)
parser.add_argument(
"-o", type=os.path.abspath, required=True, help="Output .mrc or directory"
Expand Down Expand Up @@ -72,7 +76,7 @@ def add_args(parser):
)

group = parser.add_argument_group(
"Overwrite architecture hyperparameters in config.pkl"
"Overwrite architecture hyperparameters in config.yaml"
)
group.add_argument("--norm", nargs=2, type=float)
group.add_argument("-D", type=int, help="Box size")
Expand Down
8 changes: 3 additions & 5 deletions cryodrgn/commands/train_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from cryodrgn.lattice import Lattice
from cryodrgn.pose import PoseTracker
from cryodrgn.models import DataParallelDecoder, Decoder
import cryodrgn.config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -326,10 +327,7 @@ def save_config(args, dataset, lattice, model, out_config):
dataset_args=dataset_args, lattice_args=lattice_args, model_args=model_args
)
config["seed"] = args.seed
with open(out_config, "wb") as f:
pickle.dump(config, f)
meta = dict(time=dt.now(), cmd=sys.argv, version=cryodrgn.__version__)
pickle.dump(meta, f)
cryodrgn.config.save(config, out_config)


def get_latest(args):
Expand Down Expand Up @@ -467,7 +465,7 @@ def main(args):
Apix = ctf_params[0, 0] if ctf_params is not None else 1

# save configuration
out_config = f"{args.outdir}/config.pkl"
out_config = f"{args.outdir}/config.yaml"
save_config(args, data, lattice, model, out_config)

# Mixed precision training with AMP
Expand Down
8 changes: 3 additions & 5 deletions cryodrgn/commands/train_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from cryodrgn.lattice import Lattice
from cryodrgn.models import HetOnlyVAE, unparallelize
from cryodrgn.pose import PoseTracker
import cryodrgn.config

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -532,10 +533,7 @@ def save_config(args, dataset, lattice, model, out_config):
dataset_args=dataset_args, lattice_args=lattice_args, model_args=model_args
)
config["seed"] = args.seed
with open(out_config, "wb") as f:
pickle.dump(config, f)
meta = dict(time=dt.now(), cmd=sys.argv, version=cryodrgn.__version__)
pickle.dump(meta, f)
cryodrgn.config.save(config, out_config)


def get_latest(args):
Expand Down Expand Up @@ -741,7 +739,7 @@ def main(args):
)

# save configuration
out_config = "{}/config.pkl".format(args.outdir)
out_config = "{}/config.yaml".format(args.outdir)
save_config(args, data, lattice, model, out_config)

optim = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
Expand Down
33 changes: 23 additions & 10 deletions cryodrgn/commands/view_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import argparse
import os
import pickle
import os.path
from pprint import pprint
import logging
import warnings
from cryodrgn import utils

logger = logging.getLogger(__name__)

Expand All @@ -19,16 +21,27 @@ def add_args(parser):


def main(args):
f = open(f"{args.workdir}/config.pkl", "rb")
cfg = pickle.load(f)
try:
meta = pickle.load(f)
logger.info(f'Version: {meta["version"]}')
logger.info(f'Creation time: {meta["time"]}')
warnings.warn(
"The view_config command is deprecated."
"Please save configuration in yaml format and view the config.yaml file directly.",
DeprecationWarning,
)
config_yaml = f"{args.workdir}/config.yaml"
config_pkl = f"{args.workdir}/config.pkl"
if os.path.exists(config_yaml):
cfg = utils.load_yaml(config_yaml)
elif os.path.exists(config_pkl):
cfg = utils.load_pkl(config_pkl)
else:
raise RuntimeError(f"A configuration file was not found at {args.workdir}")

if "version" in cfg:
logger.info(f'Version: {cfg["version"]}')
if "time" in cfg:
logger.info(f'Creation time: {cfg["time"]}')
if "cmd" in cfg:
logger.info("Command:")
print(" ".join(meta["cmd"]))
except: # noqa: E722
pass
print(" ".join(cfg["cmd"]))
logger.info("Config:")
pprint(cfg)

Expand Down
Loading