Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ncfrey committed May 31, 2024
1 parent 9e9e25b commit ae74b28
Show file tree
Hide file tree
Showing 31 changed files with 40 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ build/
.coverage*
.mypy_cache/
.pytest_cache/
_skbuild
_skbuild
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ select = [
"I", # isort
"B", # flake8-bugbear
]
exclude = [
"tests"
]
fixable = ["ALL"]
unfixable = []

Expand Down
2 changes: 1 addition & 1 deletion scripts/lint.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env bash

pre-commit run -a && mypy --install-types src tests
pre-commit run -a && mypy --install-types src tests
2 changes: 1 addition & 1 deletion src/walkjump/cmdline/_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ def sample(cfg: DictConfig) -> bool:
sample_df.drop_duplicates(subset=["fv_heavy_aho", "fv_light_aho"], inplace=True)
print(f"Writing {len(sample_df)} samples to {cfg.designs.output_csv}")
sample_df.to_csv(cfg.designs.output_csv, index=False)

return True
1 change: 0 additions & 1 deletion src/walkjump/cmdline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def instantiate_model_for_sample_mode(
sample_mode_model_cfg.checkpoint_path
)
if isinstance(model, NoiseEnergyModel) and sample_mode_model_cfg.denoise_path is not None:

print(
"[instantiate_model_for_sample_mode] (model.denoise_model)",
_LOG_MSG_INSTANTIATE_MODEL.format(
Expand Down
11 changes: 5 additions & 6 deletions src/walkjump/conformity/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ To keep things simple, we simulate the features of a reference distribution usin
import torch

# Pick some reference distribution
mu = torch.zeros(5)
covariance_matrix = torch.eye(5)
mu = torch.zeros(5)
covariance_matrix = torch.eye(5)
reference_distribution = torch.distributions.MultivariateNormal(mu, covariance_matrix)

X_train = reference_distribution.sample((1000,))
Expand Down Expand Up @@ -80,14 +80,14 @@ plt.show()
```



![png](../../../assets/conformity_example.png)




```python
# Mean conformity as a single statistic
# Mean conformity as a single statistic

# - > 0.5: higher conformity, more similar to training data than validation data
# - 0.5: optimal conformity, as on average, the test and validation data are equally likely under the reference distribution
Expand All @@ -103,4 +103,3 @@ print(f"Mean conformity for #2: {mean_conformity_test_2:.2f}")

Mean conformity for #1: 0.53
Mean conformity for #2: 0.19

2 changes: 1 addition & 1 deletion src/walkjump/conformity/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._conformity_score import conformity_score
from ._conformity_score import conformity_score
2 changes: 1 addition & 1 deletion src/walkjump/conformity/_conformity_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def conformity_score(log_prob: Tensor, val_log_prob: Tensor) -> Tensor:
that are less than or equal to it. Between 0 and 1, where:
- > 0.5: higher conformity, more similar to training data than validation data
- 0.5: optimal conformity, as similar to training data as validation data
- 0.5: optimal conformity, as similar to training data as validation data
- < 0.5: lower conformity, validation is more similar to training data than test data
Expand Down
1 change: 0 additions & 1 deletion src/walkjump/data/_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class AbBatch:
def from_tensor_pylist(
cls, inputs: list[torch.Tensor], vocab_size: int = len(TOKENS_AHO)
) -> "AbBatch":

packed_batch = torch.stack(inputs, dim=0)
return cls(packed_batch, vocab_size=vocab_size)

Expand Down
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/callbacks/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ defaults:


model_checkpoint:
dirpath: checkpoints
dirpath: checkpoints
filename: "{epoch}-{step}-{val_loss:.4f}"
monitor: val_loss

Expand Down
4 changes: 2 additions & 2 deletions src/walkjump/hydra_config/logger/wandb.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
_target_: lightning.pytorch.loggers.WandbLogger
save_dir: "."
offline: false
project: null
entity: null
project: null
entity: null
group: null
notes: null
tags: null
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/model/denoise.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ defaults:
- model_cfg/hyperparameters: denoise_hp
- _self_

_target_: walkjump.model.DenoiseModel
_target_: walkjump.model.DenoiseModel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: walkjump.model.arch.ByteNetArch
_target_: walkjump.model.arch.ByteNetArch
n_tokens: 21
d_model: 128
n_layers: 35
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
lr: 0.0001
weight_decay: 0.01
sigma: 1.0
sigma: 1.0
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ defaults:
- default

warmup_batches: 1
lr_start_factor: 0.1
lr_start_factor: 0.1
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ defaults:

batch_size: 64
sigma: 1.0
beta1: 0.9
beta1: 0.9
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: walkjump.model.DenoiseModel.load_from_checkpoint

checkpoint_path: ???
checkpoint_path: ???
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: walkjump.sampling.create_sampler_fn

friction: 1.0
friction: 1.0
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/model/noise_ebm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ defaults:
- _self_


_target_: walkjump.model.NoiseEnergyModel
_target_: walkjump.model.NoiseEnergyModel
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ designs:
num_samples: 100
limit_seeds: 10

device: null
device: null
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/setup/seed/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: lightning.pytorch.seed_everything

seed: 0xf1eece
workers: true
workers: true
2 changes: 1 addition & 1 deletion src/walkjump/hydra_config/setup/torch/default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
_target_: torch.set_float32_matmul_precision

precision: medium
precision: medium
1 change: 0 additions & 1 deletion src/walkjump/sampling/_langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def sachsetal(
save_trajectory: bool = False,
verbose: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:

options = _DEFAULT_SAMPLING_OPTIONS | sampling_options # overwrite

delta, gamma, lipschitz = options["delta"], options["friction"], options["lipschitz"]
Expand Down
1 change: 0 additions & 1 deletion src/walkjump/sampling/_walkjump.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def walk(
list_ys = []

for seed_chunk in seed_tensor.chunk(chunksize):

# note: apply_noise should control whether seed_chunk.requires_grad
seed_chunk = model.apply_noise(seed_chunk)
# seed_chunk.requires_grad = True
Expand Down
2 changes: 1 addition & 1 deletion tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
"++designs.seeds=denovo",
"++dryrun=true",
"++designs.redesign_regions=[L1,L2,H1,H2]",
"++model.checkpoint_path=last.ckpt"
"++model.checkpoint_path=last.ckpt",
]
1 change: 0 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd
import pytest
from sklearn.preprocessing import LabelEncoder

from walkjump.constants import TOKENS_AHO


Expand Down
6 changes: 3 additions & 3 deletions tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import hydra
import pytest
from omegaconf import DictConfig, OmegaConf

from tests.constants import CONFIG_PATH, TRAINER_OVERRIDES, SAMPLER_OVERRIDES
from walkjump.cmdline import train, sample
from walkjump.cmdline import sample, train
from walkjump.cmdline.utils import instantiate_callbacks

from tests.constants import CONFIG_PATH, SAMPLER_OVERRIDES, TRAINER_OVERRIDES

COMMAND_TO_OVERRIDES = {"train": TRAINER_OVERRIDES, "sample": SAMPLER_OVERRIDES}


Expand Down
15 changes: 5 additions & 10 deletions tests/test_conformity.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import torch
from walkjump.conformity import conformity_score

import pytest
import torch

def test_conformity_score():
log_prob = torch.tensor([1,2,4])
val_log_prob = torch.tensor([3,3,3,3])
log_prob = torch.tensor([1, 2, 4])
val_log_prob = torch.tensor([3, 3, 3, 3])

output = conformity_score(log_prob, val_log_prob)

expected = torch.tensor([
0.0,
0.0,
0.8
])
assert torch.allclose(output, expected)
expected = torch.tensor([0.0, 0.0, 0.8])
assert torch.allclose(output, expected)
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import hydra
from walkjump.constants import TOKENS_AHO
from walkjump.data import AbDataset

from tests.constants import CONFIG_PATH, TRAINER_OVERRIDES
from tests.fixtures import aho_sequence, mock_ab_dataframe # noqa: F401
from walkjump.constants import TOKENS_AHO
from walkjump.data import AbDataset


def test_abdataset(mock_ab_dataframe): # noqa: F811
Expand Down
1 change: 0 additions & 1 deletion tests/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from omegaconf import DictConfig

from walkjump.constants import ALPHABET_AHO, TOKEN_GAP
from walkjump.model import TrainableScoreModel
from walkjump.sampling import stack_seed_sequences, walkjump
Expand Down
3 changes: 2 additions & 1 deletion tests/test_tokenization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from tests.fixtures import aho_alphabet_encoder, aho_sequence # noqa: F401
from walkjump.utils import token_string_from_tensor, token_string_to_tensor

from tests.fixtures import aho_alphabet_encoder, aho_sequence # noqa: F401


def test_token_to_string_tofrom_tensor(aho_alphabet_encoder, aho_sequence): # noqa: F811
assert (
Expand Down

0 comments on commit ae74b28

Please sign in to comment.