diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f32f1c..dd948c6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased +- Add `use_gpu` for PureSVD ([#229](https://github.com/MobileTeleSystems/RecTools/pull/229)) + ## [0.10.0] - 16.01.2025 ### Added diff --git a/poetry.lock b/poetry.lock index d73fdc78..93d95540 100644 --- a/poetry.lock +++ b/poetry.lock @@ -529,6 +529,36 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cupy-cuda12x" +version = "13.3.0" +description = "CuPy: NumPy & SciPy for GPU" +optional = true +python-versions = ">=3.9" +files = [ + {file = "cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:674488e990998042cc54d2486d3c37cae80a12ba3787636be5a10b9446dd6914"}, + {file = "cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:cf4a2a0864364715881b50012927e88bd7ec1e6f1de3987970870861ae5ed25e"}, + {file = "cupy_cuda12x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7c0dc8c49d271d1c03e49a5d6c8e42e8fee3114b10f269a5ecc387731d693eaa"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c0cc095b9a3835fd5db66c45ed3c58ecdc5a3bb14e53e1defbfd4a0ce5c8ecdb"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a0e3bead04e502ebde515f0343444ca3f4f7aed09cbc3a316a946cba97f2ea66"}, + {file = "cupy_cuda12x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:5f11df1149c7219858b27e4c8be92cb4eaf7364c94af6b78c40dffb98050a61f"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bbd0d916310391faf0d7dc9c58fff7a6dc996b67e5768199160bbceb5ebdda8c"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:e206bd8664f0763732b6012431f484ee535bffd77a5ae95e9bfe1c7c72396625"}, + {file = "cupy_cuda12x-13.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:88ef1478f00ae252da0026e7f04f70c9bb6a2dc130ba5f1e5bc5e8069a928bf5"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3a52aa49ffcc940d034f2bb39728c90e9fa83c7a49e376404507956adb6d6ec4"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:3ef13f3cbc449d2a0f816594ab1fa0236e1f06ad1eaa81ad04c75e47cbeb87be"}, + {file = "cupy_cuda12x-13.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:8f5433eec3e5cd8d39e8fcb82e0fdab7c22eba8e3304fcb0b42f2ea988fef0d6"}, +] + +[package.dependencies] +fastrlock = ">=0.5" +numpy = ">=1.22,<2.3" + +[package.extras] +all = ["Cython (>=0.29.22,<3)", "optuna (>=2.0)", "scipy (>=1.7,<1.14)"] +stylecheck = ["autopep8 (==1.5.5)", "flake8 (==3.8.4)", "mypy (==1.4.1)", "pbr (==5.5.1)", "pycodestyle (==2.6.0)", "types-setuptools (==57.4.14)"] +test = ["hypothesis (>=6.37.2,<6.55.0)", "mpmath", "packaging", "pytest (>=7.2)"] + [[package]] name = "decorator" version = "5.1.1" @@ -597,6 +627,90 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "fastrlock" +version = "0.8.2" +description = "Fast, re-entrant optimistic lock implemented in Cython" +optional = true +python-versions = "*" +files = [ + {file = "fastrlock-0.8.2-cp27-cp27m-macosx_10_15_x86_64.whl", hash = "sha256:94e348c72a1fd1f8191f25ea056448e4f5a87b8fbf005b39d290dcb0581a48cd"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d5595903444c854b99c42122b87edfe8a37cd698a4eae32f4fd1d2a7b6c115d"}, + {file = "fastrlock-0.8.2-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:e4bbde174a0aff5f6eeba75cf8c4c5d2a316316bc21f03a0bddca0fc3659a6f3"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7a2ccaf88ac0db153e84305d1ef0aa138cea82c6a88309066f6eaa3bc98636cd"}, + {file = "fastrlock-0.8.2-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:31a27a2edf482df72b91fe6c6438314d2c65290aa7becc55589d156c9b91f0da"}, + {file = "fastrlock-0.8.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:e9904b5b37c3e5bb4a245c56bc4b7e497da57ffb8528f4fc39af9dcb168ee2e1"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:43a241655e83e4603a152192cf022d5ca348c2f4e56dfb02e5c9c4c1a32f9cdb"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9121a894d74e65557e47e777060a495ab85f4b903e80dd73a3c940ba042920d7"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:11bbbbc526363955aeddb9eec4cee2a0012322b7b2f15b54f44454fcf4fd398a"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:27786c62a400e282756ae1b090bcd7cfa35f28270cff65a9e7b27a5327a32561"}, + {file = "fastrlock-0.8.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:08315bde19d0c2e6b06593d5a418be3dc8f9b1ee721afa96867b9853fceb45cf"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e8b49b5743ede51e0bcf6805741f39f5e0e0fd6a172ba460cb39e3097ba803bb"}, + {file = "fastrlock-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b443e73a4dfc7b6e0800ea4c13567b9694358e86f53bb2612a51c9e727cac67b"}, + {file = "fastrlock-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:b3853ed4ce522598dc886160a7bab432a093051af85891fa2f5577c1dcac8ed6"}, + {file = "fastrlock-0.8.2-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:790fc19bccbd39426060047e53629f171a44745613bf360a045e9f9c8c4a2cea"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:dbdce852e6bb66e1b8c36679d482971d69d93acf1785657522e51b7de30c3356"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d47713ffe6d4a627fbf078be9836a95ac106b4a0543e3841572c91e292a5d885"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:ea96503b918fceaf40443182742b8964d47b65c5ebdea532893cb9479620000c"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:c6bffa978793bea5e1b00e677062e53a62255439339591b70e209fa1552d5ee0"}, + {file = "fastrlock-0.8.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:75c07726c8b1a52147fd7987d6baaa318c5dced1416c3f25593e40f56e10755b"}, + {file = "fastrlock-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:88f079335e9da631efa64486c8207564a7bcd0c00526bb9e842e9d5b7e50a6cc"}, + {file = "fastrlock-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4fb2e77ff04bc4beb71d63c8e064f052ce5a6ea1e001d528d4d7f4b37d736f2e"}, + {file = "fastrlock-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:b4c9083ea89ab236b06e9ef2263971db3b4b507195fc7d5eecab95828dcae325"}, + {file = "fastrlock-0.8.2-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:98195866d3a9949915935d40a88e4f1c166e82e378f622c88025f2938624a90a"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b22ea9bf5f9fad2b0077e944a7813f91593a4f61adf8faf734a70aed3f2b3a40"}, + {file = "fastrlock-0.8.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dcc1bf0ac8a194313cf6e645e300a8a379674ceed8e0b1e910a2de3e3c28989e"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a3dcc876050b8f5cbc0ee84ef1e7f0c1dfe7c148f10098828bc4403683c33f10"}, + {file = "fastrlock-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:685e656048b59d8dfde8c601f188ad53a4d719eb97080cafc8696cda6d75865e"}, + {file = "fastrlock-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:fb5363cf0fddd9b50525ddbf64a1e1b28ec4c6dfb28670a940cb1cf988a6786b"}, + {file = "fastrlock-0.8.2-cp35-cp35m-macosx_10_15_x86_64.whl", hash = "sha256:a74f5a92fa6e51c4f3c69b29c4662088b97be12f40652a21109605a175c81824"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ccf39ad5702e33e4d335b48ef9d56e21619b529b7f7471b5211419f380329b62"}, + {file = "fastrlock-0.8.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:66f2662c640bb71a1016a031eea6eef9d25c2bcdf7ffd1d1ddc5a58f9a1ced04"}, + {file = "fastrlock-0.8.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:17734e2e5af4c07ddb0fb10bd484e062c22de3be6b67940b9cc6ec2f18fa61ba"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:ab91b0c36e95d42e1041a4907e3eefd06c482d53af3c7a77be7e214cc7cd4a63"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b32fdf874868326351a75b1e4c02f97e802147119ae44c52d3d9da193ec34f5b"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:2074548a335fcf7d19ebb18d9208da9e33b06f745754466a7e001d2b1c58dd19"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fb04442b6d1e2b36c774919c6bcbe3339c61b337261d4bd57e27932589095af"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:1fed2f4797ad68e9982038423018cf08bec5f4ce9fed63a94a790773ed6a795c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e380ec4e6d8b26e389713995a43cb7fe56baea2d25fe073d4998c4821a026211"}, + {file = "fastrlock-0.8.2-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:25945f962c7bd808415cfde3da624d4399d4ea71ed8918538375f16bceb79e1c"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:2c1719ddc8218b01e82fb2e82e8451bd65076cb96d7bef4477194bbb4305a968"}, + {file = "fastrlock-0.8.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:5460c5ee6ced6d61ec8cd2324ebbe793a4960c4ffa2131ffff480e3b61c99ec5"}, + {file = "fastrlock-0.8.2-cp36-cp36m-win_amd64.whl", hash = "sha256:33145acbad8317584cd64588131c7e1e286beef6280c0009b4544c91fce171d2"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:59344c1d46b7dec97d3f22f1cc930fafe8980b3c5bc9c9765c56738a5f1559e4"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2a1c354f13f22b737621d914f3b4a8434ae69d3027a775e94b3e671756112f9"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:cf81e0278b645004388873e0a1f9e3bc4c9ab8c18e377b14ed1a544be4b18c9a"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1b15430b93d7eb3d56f6ff690d2ebecb79ed0e58248427717eba150a508d1cd7"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:067edb0a0805bf61e17a251d5046af59f6e9d2b8ad01222e0ef7a0b7937d5548"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb31fe390f03f7ae886dcc374f1099ec88526631a4cb891d399b68181f154ff0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:643e1e65b4f5b284427e61a894d876d10459820e93aa1e724dfb415117be24e0"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:5dfb78dd600a12f23fc0c3ec58f81336229fdc74501ecf378d1ce5b3f2f313ea"}, + {file = "fastrlock-0.8.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:b8ca0fe21458457077e4cb2d81e1ebdb146a00b3e9e2db6180a773f7ea905032"}, + {file = "fastrlock-0.8.2-cp37-cp37m-win_amd64.whl", hash = "sha256:d918dfe473291e8bfd8e13223ea5cb9b317bd9f50c280923776c377f7c64b428"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:c393af77c659a38bffbca215c0bcc8629ba4299568308dd7e4ff65d62cabed39"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:73426f5eb2ecc10626c67cf86bd0af9e00d53e80e5c67d5ce8e18376d6abfa09"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:320fd55bafee3eb069cfb5d6491f811a912758387ef2193840e2663e80e16f48"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8c1c91a68926421f5ccbc82c85f83bd3ba593b121a46a1b9a554b3f0dd67a4bf"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:ad1bc61c7f6b0e58106aaab034916b6cb041757f708b07fbcdd9d6e1ac629225"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:87f4e01b042c84e6090dbc4fbe3415ddd69f6bc0130382323f9d3f1b8dd71b46"}, + {file = "fastrlock-0.8.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:d34546ad2e4a480b94b6797bcc5a322b3c705c4c74c3e4e545c4a3841c1b2d59"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:ebb32d776b61acd49f859a1d16b9e3d84e7b46d0d92aebd58acd54dc38e96664"}, + {file = "fastrlock-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:30bdbe4662992348132d03996700e1cf910d141d629179b967b146a22942264e"}, + {file = "fastrlock-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:07ed3c7b3867c05a3d6be4ced200c7767000f3431b9be6da66972822dd86e8be"}, + {file = "fastrlock-0.8.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:ddf5d247f686aec853ddcc9a1234bfcc6f57b0a0670d2ad82fc25d8ae7e6a15f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:7269bb3fc15587b0c191eecd95831d771a7d80f0c48929e560806b038ff3066c"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adcb9e77aa132cc6c9de2ffe7cf880a20aa8cdba21d367d1da1a412f57bddd5d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:a3b8b5d2935403f1b4b25ae324560e94b59593a38c0d2e7b6c9872126a9622ed"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2587cedbb36c7988e707d83f0f1175c1f882f362b5ebbee25d70218ea33d220d"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:9af691a9861027181d4de07ed74f0aee12a9650ac60d0a07f4320bff84b5d95f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:99dd6652bd6f730beadf74ef769d38c6bbd8ee6d1c15c8d138ea680b0594387f"}, + {file = "fastrlock-0.8.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:4d63b6596368dab9e0cc66bf047e7182a56f33b34db141816a4f21f5bf958228"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ff75c90663d6e8996610d435e71487daa853871ad1770dd83dc0f2fc4997241e"}, + {file = "fastrlock-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e27c3cd27fbd25e5223c5c992b300cd4ee8f0a75c6f222ce65838138d853712c"}, + {file = "fastrlock-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:dd961a32a7182c3891cdebca417fda67496d5d5de6ae636962254d22723bdf52"}, + {file = "fastrlock-0.8.2.tar.gz", hash = "sha256:644ec9215cf9c4df8028d8511379a15d9c1af3e16d80e47f1b6fdc6ba118356a"}, +] + [[package]] name = "filelock" version = "3.16.1" @@ -3638,7 +3752,8 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "rectools-lightfm", "torch", "torch"] +all = ["cupy-cuda12x", "ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "rectools-lightfm", "torch", "torch"] +cupy = ["cupy-cuda12x"] lightfm = ["rectools-lightfm"] nmslib = ["nmslib", "nmslib-metabrainz"] torch = ["pytorch-lightning", "torch", "torch"] @@ -3647,4 +3762,4 @@ visuals = ["ipywidgets", "nbformat", "plotly"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.13" -content-hash = "0be90e416bf1a931732bf02749481c4569077fa2ee63291abd6965d604c0d04a" +content-hash = "894765cd6220c87dd20ad10a0926b39298bfbc2d215ef0091dbf0849a468b349" diff --git a/pyproject.toml b/pyproject.toml index 4b13ac67..50a9b973 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ pytorch-lightning = {version = ">=1.6.0, <3.0.0", optional = true} ipywidgets = {version = ">=7.7,<8.2", optional = true} plotly = {version="^5.22.0", optional = true} nbformat = {version = ">=4.2.0", optional = true} +cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true} [tool.poetry.extras] @@ -91,11 +92,13 @@ lightfm = ["rectools-lightfm"] nmslib = ["nmslib", "nmslib-metabrainz"] torch = ["torch", "pytorch-lightning"] visuals = ["ipywidgets", "plotly", "nbformat"] +cupy = ["cupy-cuda12x"] all = [ "rectools-lightfm", "nmslib", "nmslib-metabrainz", "torch", "pytorch-lightning", "ipywidgets", "plotly", "nbformat", + "cupy-cuda12x", ] diff --git a/rectools/models/pure_svd.py b/rectools/models/pure_svd.py index 3d5d5de5..a0ba2153 100644 --- a/rectools/models/pure_svd.py +++ b/rectools/models/pure_svd.py @@ -15,6 +15,7 @@ """SVD Model.""" import typing as tp +import warnings import numpy as np import typing_extensions as tpe @@ -26,6 +27,15 @@ from rectools.models.rank import Distance from rectools.models.vector import Factors, VectorModel +try: + import cupy as cp + from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix + from cupyx.scipy.sparse.linalg import svds as cupy_svds +except ImportError: # pragma: no cover + cupy_svds = None + cp_csr_matrix = None + cp = None + class PureSVDModelConfig(ModelConfig): """Config for `PureSVD` model.""" @@ -34,6 +44,7 @@ class PureSVDModelConfig(ModelConfig): tol: float = 0 maxiter: tp.Optional[int] = None random_state: tp.Optional[int] = None + use_gpu: tp.Optional[bool] = False recommend_n_threads: int = 0 recommend_use_gpu_ranking: bool = True @@ -53,7 +64,9 @@ class PureSVDModel(VectorModel[PureSVDModelConfig]): maxiter : int, optional, default ``None`` Maximum number of iterations. random_state : int, optional, default ``None`` - Pseudorandom number generator state used to generate resamples. + Pseudorandom number generator state used to generate resamples. Omitted if use_gpu is True. + use_gpu : bool, default ``False`` + If ``True``, `cupyx.scipy.sparse.linalg.svds()` is used instead of SciPy. CuPy is required. verbose : int, default ``0`` Degree of verbose output. If ``0``, no output will be provided. recommend_n_threads: int, default 0 @@ -83,6 +96,7 @@ def __init__( tol: float = 0, maxiter: tp.Optional[int] = None, random_state: tp.Optional[int] = None, + use_gpu: tp.Optional[bool] = False, verbose: int = 0, recommend_n_threads: int = 0, recommend_use_gpu_ranking: bool = True, @@ -93,6 +107,16 @@ def __init__( self.tol = tol self.maxiter = maxiter self.random_state = random_state + self._use_gpu = use_gpu # for making a config + if use_gpu: # pragma: no cover + if not cp: + warnings.warn("Forced to use CPU. CuPy is not available.") + use_gpu = False + elif not cp.cuda.is_available(): + warnings.warn("Forced to use CPU. GPU is not available.") + use_gpu = False + + self.use_gpu = use_gpu self.recommend_n_threads = recommend_n_threads self.recommend_use_gpu_ranking = recommend_use_gpu_ranking @@ -106,6 +130,7 @@ def _get_config(self) -> PureSVDModelConfig: tol=self.tol, maxiter=self.maxiter, random_state=self.random_state, + use_gpu=self._use_gpu, verbose=self.verbose, recommend_n_threads=self.recommend_n_threads, recommend_use_gpu_ranking=self.recommend_use_gpu_ranking, @@ -118,6 +143,7 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self: tol=config.tol, maxiter=config.maxiter, random_state=config.random_state, + use_gpu=config.use_gpu, verbose=config.verbose, recommend_n_threads=config.recommend_n_threads, recommend_use_gpu_ranking=config.recommend_use_gpu_ranking, @@ -126,10 +152,19 @@ def _from_config(cls, config: PureSVDModelConfig) -> tpe.Self: def _fit(self, dataset: Dataset) -> None: # type: ignore ui_csr = dataset.get_user_item_matrix(include_weights=True) - u, sigma, vt = svds(ui_csr, k=self.factors, tol=self.tol, maxiter=self.maxiter, random_state=self.random_state) + if self.use_gpu: # pragma: no cover + ui_csr = cp_csr_matrix(ui_csr) + # To prevent IndexError, we need to subtract 1 from factors + u, sigma, vt = cupy_svds(ui_csr.toarray(), k=self.factors - 1, tol=self.tol, maxiter=self.maxiter) + u = u.get() + self.item_factors = (cp.diag(sigma) @ vt).T.get() + else: + u, sigma, vt = svds( + ui_csr, k=self.factors, tol=self.tol, maxiter=self.maxiter, random_state=self.random_state + ) + self.item_factors = (np.diag(sigma) @ vt).T self.user_factors = u - self.item_factors = (np.diag(sigma) @ vt).T def _get_users_factors(self, dataset: Dataset) -> Factors: return Factors(self.user_factors) diff --git a/tests/models/test_pure_svd.py b/tests/models/test_pure_svd.py index 71a7ed67..408023da 100644 --- a/tests/models/test_pure_svd.py +++ b/tests/models/test_pure_svd.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd import pytest +from pytest_mock import MockerFixture from rectools import Columns from rectools.dataset import Dataset @@ -32,6 +33,17 @@ assert_second_fit_refits_model, ) +try: + import cupy as cp # pylint: disable=import-error, unused-import +except ImportError: # pragma: no cover + cp = None + +try: + HAS_CUDA = cp.is_available() if cp else False +except Exception: # pragma: no cover # pylint: disable=broad-except + # If CUDA isn't installed cupy raises CUDARuntimeError: + HAS_CUDA = False + class TestPureSVDModel: @@ -72,6 +84,7 @@ def test_basic( expected: pd.DataFrame, use_gpu_ranking: bool, ) -> None: + model = PureSVDModel(factors=2, recommend_use_gpu_ranking=use_gpu_ranking).fit(dataset) actual = model.recommend( users=np.array([10, 20]), @@ -85,6 +98,52 @@ def test_basic( actual, ) + # SciPy's svds and cupy's svds results can be different and use_gpu fallback causes errors + @pytest.mark.skipif(cp is None or not HAS_CUDA, reason="CUDA is not available") + @pytest.mark.parametrize( + "filter_viewed,expected", + ( + ( + True, + pd.DataFrame( + { + Columns.User: [10, 10, 20, 20], + Columns.Item: [15, 13, 14, 15], + Columns.Rank: [1, 2, 1, 2], + } + ), + ), + ( + False, + pd.DataFrame( + { + Columns.User: [10, 10, 20, 20], + Columns.Item: [11, 12, 11, 12], + Columns.Rank: [1, 2, 1, 2], + } + ), + ), + ), + ) + def test_basic_gpu( + self, + dataset: Dataset, + filter_viewed: bool, + expected: pd.DataFrame, + ) -> None: + model = PureSVDModel(factors=2, use_gpu=True, recommend_use_gpu_ranking=True).fit(dataset) + actual = model.recommend( + users=np.array([10, 20]), + dataset=dataset, + k=2, + filter_viewed=filter_viewed, + ) + pd.testing.assert_frame_equal(actual.drop(columns=Columns.Score), expected) + pd.testing.assert_frame_equal( + actual.sort_values([Columns.User, Columns.Score], ascending=[True, False]).reset_index(drop=True), + actual, + ) + @pytest.mark.parametrize( "filter_viewed,expected", ( @@ -279,12 +338,16 @@ def test_dumps_loads(self, dataset: Dataset) -> None: class TestPureSVDModelConfiguration: - def test_from_config(self) -> None: + @pytest.mark.parametrize("use_gpu", (False, True)) + def test_from_config(self, mocker: MockerFixture, use_gpu: bool) -> None: + mocker.patch("rectools.models.pure_svd.cp", return_value=True) + mocker.patch("rectools.models.pure_svd.cp.cuda.is_available", return_value=True) config = { "factors": 100, "tol": 0, "maxiter": 100, "random_state": 32, + "use_gpu": use_gpu, "verbose": 0, } model = PureSVDModel.from_config(config) @@ -296,12 +359,18 @@ def test_from_config(self) -> None: @pytest.mark.parametrize("random_state", (None, 42)) @pytest.mark.parametrize("simple_types", (False, True)) - def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> None: + @pytest.mark.parametrize("use_gpu", (False, True)) + def test_get_config( + self, mocker: MockerFixture, random_state: tp.Optional[int], simple_types: bool, use_gpu: bool + ) -> None: + mocker.patch("rectools.models.pure_svd.cp.cuda.is_available", return_value=True) + mocker.patch("rectools.models.pure_svd.cp", return_value=True) model = PureSVDModel( factors=100, - tol=1, + tol=1.0, maxiter=100, random_state=random_state, + use_gpu=use_gpu, verbose=1, recommend_n_threads=2, recommend_use_gpu_ranking=False, @@ -310,9 +379,10 @@ def test_get_config(self, random_state: tp.Optional[int], simple_types: bool) -> expected = { "cls": "PureSVDModel" if simple_types else PureSVDModel, "factors": 100, - "tol": 1, + "tol": 1.0, "maxiter": 100, "random_state": random_state, + "use_gpu": use_gpu, "verbose": 1, "recommend_n_threads": 2, "recommend_use_gpu_ranking": False,