-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpyproject.toml
78 lines (64 loc) · 2.73 KB
/
pyproject.toml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Package metadata consumed by Poetry (name, version, authorship, license).
[tool.poetry]
name = "typed-transformer"
version = "0.1.0"
description = "A typed transformer implementation in JAX"
authors = ["Cole Haus <colehaus@cryptolab.net>"]
# SPDX license identifier.
license = "AGPL-3.0-or-later"
# Runtime dependencies. Entries marked `optional = true` are only installed
# through the groups declared in [tool.poetry.extras].
[tool.poetry.dependencies]
# Restricted to the 3.12 series: the GPU jaxlib wheel below is a cp312 build.
python = "~3.12.1"
equinox = "^0.11.3"
jax = "^0.4.25"
jaxlib = [
    # Mildly hacky way to detect GPU vs not.
    # It would make more sense to base the marker on `extra` but that's not possible.
    { version = "^0.4.25", source = "pypi", markers = "'NixOS' in platform_version" },
    # Manually specify resolution of jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_releases.html
    # because of https://github.com/python-poetry/poetry/issues/6819
    { url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.25+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl", markers = "'NixOS' not in platform_version" },
]
numpy = "^1.26.0"
optax = "^0.1.7"
tqdm = "^4.66.1"

# Optional: CUDA 12 libraries, exposed through the `gpu` extra.
nvidia-cublas-cu12 = { version = "*", optional = true }
nvidia-cuda-cupti-cu12 = { version = "*", optional = true }
nvidia-cuda-nvcc-cu12 = { version = "*", optional = true }
nvidia-cuda-runtime-cu12 = { version = "*", optional = true }
# Capped below 9.0: the jaxlib wheel pinned above is a `cudnn89` build, so an
# unbounded `>=8.9` could resolve to a cuDNN 9.x release it cannot load.
nvidia-cudnn-cu12 = { version = ">=8.9,<9.0", optional = true }
nvidia-cufft-cu12 = { version = "*", optional = true }
nvidia-cusolver-cu12 = { version = "*", optional = true }
nvidia-cusparse-cu12 = { version = "*", optional = true }
nvidia-ml-py3 = { version = "^7.352.0", optional = true }

# Optional: notebook tooling, exposed through the `gradient` and `vscode` extras.
jupyter = { version = "^1.0.0", optional = true }
nbclassic = { version = "^1.0.0", optional = true }
notebook = { version = "^7.0.6", optional = true }

# Optional: static type checker, exposed through the `dev` extra.
pyright = { version = "^1.1.351", optional = true }
# Named install groups. Every package listed here is declared with
# `optional = true` in [tool.poetry.dependencies].
[tool.poetry.extras]
dev = ["pyright"]
gpu = [
    "nvidia-cublas-cu12",
    "nvidia-cuda-cupti-cu12",
    "nvidia-cuda-nvcc-cu12",
    "nvidia-cuda-runtime-cu12",
    "nvidia-cudnn-cu12",
    "nvidia-cufft-cu12",
    "nvidia-cusolver-cu12",
    "nvidia-cusparse-cu12",
    "nvidia-ml-py3",
]
gradient = ["jupyter"]
vscode = ["jupyter", "nbclassic", "notebook"]
# PEP 517 build configuration: build with poetry-core as the backend.
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
# Black formatter settings. The line length is the same value as [tool.ruff].
[tool.black]
line-length = 115
# Regex of paths to format: files ending in `.py` or `.pyi`.
include = '\.pyi?$'
# Ruff top-level settings. The line length is the same value as [tool.black].
[tool.ruff]
line-length = 115
# Ruff lint rule selection: enable whole rule families via `select`, then
# switch off individual rules via `ignore`.
[tool.ruff.lint]
select = [
    "A",     # flake8-builtins
    "B",     # flake8-bugbear
    "C90",   # mccabe
    "C4",    # flake8-comprehensions
    "COM",   # flake8-commas
    "D",     # pydocstyle
    "E",     # pycodestyle errors
    "F",     # Pyflakes
    "FA",    # flake8-future-annotations
    "FBT",   # flake8-boolean-trap
    "FURB",  # refurb
    "I",     # isort
    "LOG",   # flake8-logging
    "N",     # pep8-naming
    "PD",    # pandas-vet
    "PERF",  # Perflint
    "PIE",   # flake8-pie
    "PL",    # Pylint
    "PT",    # flake8-pytest-style
    "PTH",   # flake8-use-pathlib
    "RET",   # flake8-return
    "RUF",   # Ruff-specific rules
    "SIM",   # flake8-simplify
    "UP",    # pyupgrade
]
# One line per rule family.
ignore = [
    "B008",
    "COM812",
    "D100", "D101", "D102", "D103", "D105", "D106", "D107",
    "D205", "D400", "D415",
    "E203", "E402", "E741",
    "F401", "F405", "F821", "F841",
    "PD008",
    "RET505", "RET506",
]
# Treat the fix for UP007 as safe, so it is applied without opting in to
# unsafe fixes.
extend-safe-fixes = ["UP007"]
# Extra rule exemptions keyed by file glob. The keys must stay quoted because
# they contain `*` and `.`.
[tool.ruff.lint.per-file-ignores]
# Type stub files.
"*.pyi" = [
    "A",
    "F403",
    "FBT001", "FBT002",
    "N801", "N802", "N803", "N815", "N816",
    "PLR0904", "PLR0913",
    "PLW1641",
]
# Files whose name starts with `test`.
"test*.py" = ["PLR2004", "PLR6301"]