Skip to content

Commit

Permalink
Reformat with ruff (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored May 6, 2024
1 parent e5a7f1c commit 998dbb2
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 75 deletions.
37 changes: 4 additions & 33 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,38 +1,9 @@
repos:
- repo: local
hooks:
- id: isort
name: "[Package] Import formatting"
- id: ruff
name: "[Package] Formatting"
language: system
entry: isort
entry: make
args: [ lint ]
files: \.py$

- id: black
name: "[Package] Code formatting"
language: system
entry: black
files: \.py$

- id: flake8
name: "[Package] Linting"
language: system
entry: flake8
files: \.py$

- id: isort-examples
name: "[Examples] Import formatting"
language: system
entry: nbqa isort
files: examples/.+\.ipynb$

- id: black-examples
name: "[Examples] Code formatting"
language: system
entry: nbqa black
files: examples/.+\.ipynb$

- id: flake8-examples
name: "[Examples] Linting"
language: system
entry: nbqa flake8 --ignore=E402
files: examples/.+\.ipynb$
23 changes: 9 additions & 14 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
PACKAGE_NAME := descent
PACKAGE_NAME := descent
PACKAGE_DIR := $(PACKAGE_NAME)

CONDA_ENV_RUN := conda run --no-capture-output --name $(PACKAGE_NAME)

.PHONY: pip-install env lint format test test-examples
Expand All @@ -13,23 +15,16 @@ env:
$(CONDA_ENV_RUN) pre-commit install || true

lint:
$(CONDA_ENV_RUN) isort --check-only $(PACKAGE_NAME)
$(CONDA_ENV_RUN) black --check $(PACKAGE_NAME)
$(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME)
$(CONDA_ENV_RUN) nbqa isort --check-only examples
$(CONDA_ENV_RUN) nbqa black --check examples
$(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples
$(CONDA_ENV_RUN) ruff check $(PACKAGE_DIR)

format:
$(CONDA_ENV_RUN) isort $(PACKAGE_NAME)
$(CONDA_ENV_RUN) black $(PACKAGE_NAME)
$(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME)
$(CONDA_ENV_RUN) nbqa isort examples
$(CONDA_ENV_RUN) nbqa black examples
$(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples
$(CONDA_ENV_RUN) ruff format $(PACKAGE_DIR)
$(CONDA_ENV_RUN) ruff check --fix --select I $(PACKAGE_DIR)
$(CONDA_ENV_RUN) nbqa 'ruff format' examples
$(CONDA_ENV_RUN) nbqa 'ruff check' --fix --select=I examples

test:
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/
$(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_DIR)/tests/

docs-build:
$(CONDA_ENV_RUN) mkdocs build
Expand Down
3 changes: 2 additions & 1 deletion descent/targets/dimers.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ def predict(
*[
_predict(dimer, force_field, topologies)
for dimer in descent.utils.dataset.iter_dataset(dataset)
]
],
strict=True,
)
return torch.cat(reference), torch.cat(predicted)

Expand Down
2 changes: 1 addition & 1 deletion descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def predict(

per_type_scales = per_type_scales if per_type_scales is not None else {}

for entry, keys in zip(entries, entry_to_simulation):
for entry, keys in zip(entries, entry_to_simulation, strict=True):
value, std = _predict(entry, keys, observables, required_simulations)

type_scale = per_type_scales.get(entry["type"], 1.0)
Expand Down
5 changes: 3 additions & 2 deletions descent/tests/optim/test_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def test_damping_factor_loss_fn(mocker):


@pytest.mark.parametrize(
"n_convergence_criteria, n_convergence_steps, step_quality, expected_converged, expected_logs",
"n_convergence_criteria, n_convergence_steps, step_quality, expected_converged, "
"expected_logs",
[
(0, 2, 1.0, False, []),
(1, 2, 0.0, False, []),
Expand Down Expand Up @@ -283,7 +284,7 @@ def mock_loss_fn(_x, *_):
]
assert len(trust_radius_messages) == len(expected_messages)

for message, expected in zip(trust_radius_messages, expected_messages):
for message, expected in zip(trust_radius_messages, expected_messages, strict=True):
assert message.startswith(expected)

# mock_step_fn.assert_has_calls(expected_loss_traj, any_order=False)
Expand Down
1 change: 0 additions & 1 deletion descent/tests/utils/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def mock_loss_fn(x: torch.Tensor, a: float, b: float) -> torch.Tensor:


def test_combine_closures():

def mock_closure_a(x_, compute_gradient, compute_hessian):
loss = x_[0] ** 2
grad = 2 * x_[0] if compute_gradient else None
Expand Down
4 changes: 2 additions & 2 deletions descent/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@ def iter_dataset(dataset: datasets.Dataset) -> typing.Iterator[dict[str, typing.

columns = [*dataset.features]

for row in zip(*[dataset[column] for column in columns]):
yield {column: v for column, v in zip(columns, row)}
for row in zip(*[dataset[column] for column in columns], strict=True):
yield dict(zip(columns, row, strict=True))
2 changes: 0 additions & 2 deletions descent/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,13 @@ def combine_closures(
def combined_closure_fn(
x: torch.Tensor, compute_gradient: bool, compute_hessian: bool
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:

loss = []
grad = None if not compute_gradient else []
hess = None if not compute_hessian else []

verbose_rows = []

for name, closure_fn in closures.items():

local_loss, local_grad, local_hess = closure_fn(
x, compute_gradient, compute_hessian
)
Expand Down
17 changes: 12 additions & 5 deletions descent/utils/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@


DEFAULT_COLORS, DEFAULT_MARKERS = zip(
*itertools.product(["red", "green", "blue", "black"], ["x", "o", "+", "^"])
*itertools.product(["red", "green", "blue", "black"], ["x", "o", "+", "^"]),
strict=True,
)


Expand Down Expand Up @@ -103,11 +104,15 @@ def print_potential_summary(potential: smee.TensorPotential):

parameter_rows = []

for key, value in zip(potential.parameter_keys, potential.parameters.detach()):
for key, value in zip(
potential.parameter_keys, potential.parameters.detach(), strict=True
):
row = {"ID": _format_parameter_id(key.id)}
row.update(
{
f"{col}{_format_unit(potential.parameter_units[idx])}": f"{value[idx].item():.4f}"
f"{col}{_format_unit(potential.parameter_units[idx])}": (
f"{value[idx].item():.4f}"
)
for idx, col in enumerate(potential.parameter_cols)
}
)
Expand All @@ -119,7 +124,9 @@ def print_potential_summary(potential: smee.TensorPotential):
if potential.attributes is not None:
attribute_rows = [
{
f"{col}{_format_unit(potential.attribute_units[idx])}": f"{potential.attributes[idx].item():.4f} "
f"{col}{_format_unit(potential.attribute_units[idx])}": (
f"{potential.attributes[idx].item():.4f} "
)
for idx, col in enumerate(potential.attribute_cols)
}
]
Expand All @@ -139,7 +146,7 @@ def print_v_site_summary(v_sites: smee.TensorVSites):

parameter_rows = []

for key, value in zip(v_sites.keys, v_sites.parameters.detach()):
for key, value in zip(v_sites.keys, v_sites.parameters.detach(), strict=True):
row = {"ID": _format_parameter_id(key.id)}
row.update(
{
Expand Down
5 changes: 1 addition & 4 deletions devtools/envs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@ dependencies:
- versioneer

- pre-commit
- isort
- black
- flake8
- flake8-pyproject
- ruff
- nbqa

- pytest
Expand Down
14 changes: 4 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,10 @@ versionfile_build = "descent/_version.py"
tag_prefix = ""
parentdir_prefix = "descent-"

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.flake8]
max-line-length = 88
ignore = ["E203", "E266", "E501", "W503"]
select = ["B","C","E","F","W","T4","B9"]
[tool.ruff.lint]
ignore = ["C901"]
select = ["B","C","E","F","W","B9"]
ignore-init-module-imports = true

[tool.coverage.run]
omit = ["**/tests/*", "**/_version.py"]
Expand Down

0 comments on commit 998dbb2

Please sign in to comment.