-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for LogitsTemperatureScaler in the new ModelOutput API #815
Conversation
Click to view CI ResultsGitHub pull request #815 of commit 6ebbb78330f12a3e5bc47283faccee5d6fda506f, no merge conflicts. Running as SYSTEM Setting status of 6ebbb78330f12a3e5bc47283faccee5d6fda506f to PENDING with url https://10.20.13.93:8080/job/merlin_models/1552/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_models using credential nvidia-merlin-bot > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/models/ # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/models/ > git --version # timeout=10 using GIT_ASKPASS to set credentials This is the bot credentials for our CI/CD > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/models/ +refs/pull/815/*:refs/remotes/origin/pr/815/* # timeout=10 > git rev-parse 6ebbb78330f12a3e5bc47283faccee5d6fda506f^{commit} # timeout=10 Checking out Revision 6ebbb78330f12a3e5bc47283faccee5d6fda506f (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 6ebbb78330f12a3e5bc47283faccee5d6fda506f # timeout=10 Commit message: "fix logit scaler for new ModelOutput API" > git rev-list --no-walk 36292c0ed34cdbc9c5d610ccd896e4ea09fd32ec # timeout=10 [merlin_models] $ /bin/bash /tmp/jenkins7044676518061846842.sh Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Requirement already satisfied: testbook in /usr/local/lib/python3.8/dist-packages (0.4.2) Requirement already satisfied: nbformat>=5.0.4 in /usr/local/lib/python3.8/dist-packages (from testbook) (5.5.0) Requirement already satisfied: nbclient>=0.4.0 in /usr/local/lib/python3.8/dist-packages (from testbook) (0.6.8) Requirement already satisfied: fastjsonschema in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (2.16.1) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.16.0) Requirement already satisfied: jupyter_core in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.11.1) Requirement already satisfied: traitlets>=5.1 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (5.4.0) Requirement already satisfied: jupyter-client>=6.1.5 in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (7.3.5) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (1.5.5) Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (22.1.0) Requirement already satisfied: importlib-resources>=1.4.0; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (5.9.0) Requirement already satisfied: pkgutil-resolve-name>=1.3.10; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (1.3.10) Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (0.18.1) Requirement already satisfied: entrypoints in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (0.4) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (2.8.2) Requirement already satisfied: pyzmq>=23.0 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (24.0.0) Requirement already satisfied: tornado>=6.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (6.2) Requirement already satisfied: zipp>=3.1.0; python_version < "3.10" in /usr/local/lib/python3.8/dist-packages (from importlib-resources>=1.4.0; python_version < "3.9"->jsonschema>=2.6->nbformat>=5.0.4->testbook) (3.8.1) Requirement already satisfied: six>=1.5 in /var/jenkins_home/.local/lib/python3.8/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (1.15.0) ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.3, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_models/models, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-4.0.0 collected 775 items |
rerun tests |
Click to view CI ResultsGitHub pull request #815 of commit 6ebbb78330f12a3e5bc47283faccee5d6fda506f, no merge conflicts. Running as SYSTEM Setting status of 6ebbb78330f12a3e5bc47283faccee5d6fda506f to PENDING with url https://10.20.13.93:8080/job/merlin_models/1555/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_models using credential nvidia-merlin-bot > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/models/ # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/models/ > git --version # timeout=10 using GIT_ASKPASS to set credentials This is the bot credentials for our CI/CD > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/models/ +refs/pull/815/*:refs/remotes/origin/pr/815/* # timeout=10 > git rev-parse 6ebbb78330f12a3e5bc47283faccee5d6fda506f^{commit} # timeout=10 Checking out Revision 6ebbb78330f12a3e5bc47283faccee5d6fda506f (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 6ebbb78330f12a3e5bc47283faccee5d6fda506f # timeout=10 Commit message: "fix logit scaler for new ModelOutput API" > git rev-list --no-walk 2c2062dcb9c8ec545bf784415cf3acafca86d14d # timeout=10 [merlin_models] $ /bin/bash /tmp/jenkins7166896445195625129.sh Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Requirement already satisfied: testbook in /usr/local/lib/python3.8/dist-packages (0.4.2) Requirement already satisfied: nbformat>=5.0.4 in /usr/local/lib/python3.8/dist-packages (from testbook) (5.5.0) Requirement already satisfied: nbclient>=0.4.0 in /usr/local/lib/python3.8/dist-packages (from testbook) (0.6.8) Requirement already satisfied: fastjsonschema in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (2.16.1) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.16.0) Requirement already satisfied: jupyter_core in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.11.1) Requirement already satisfied: traitlets>=5.1 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (5.4.0) Requirement already satisfied: jupyter-client>=6.1.5 in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (7.3.5) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (1.5.5) Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (22.1.0) Requirement already satisfied: importlib-resources>=1.4.0; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (5.9.0) Requirement already satisfied: pkgutil-resolve-name>=1.3.10; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (1.3.10) Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (0.18.1) Requirement already satisfied: entrypoints in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (0.4) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (2.8.2) Requirement already satisfied: pyzmq>=23.0 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (24.0.0) Requirement already satisfied: tornado>=6.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (6.2) Requirement already satisfied: zipp>=3.1.0; python_version < "3.10" in /usr/local/lib/python3.8/dist-packages (from importlib-resources>=1.4.0; python_version < "3.9"->jsonschema>=2.6->nbformat>=5.0.4->testbook) (3.8.1) Requirement already satisfied: six>=1.5 in /var/jenkins_home/.local/lib/python3.8/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (1.15.0) ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.3, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_models/models, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-4.0.0 collected 775 items |
Click to view CI ResultsGitHub pull request #815 of commit 2c6405ce9fc2b75491206f235e4ca16685e1e4e2, no merge conflicts. Running as SYSTEM Setting status of 2c6405ce9fc2b75491206f235e4ca16685e1e4e2 to PENDING with url https://10.20.13.93:8080/job/merlin_models/1561/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_models using credential nvidia-merlin-bot > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/models/ # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/models/ > git --version # timeout=10 using GIT_ASKPASS to set credentials This is the bot credentials for our CI/CD > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/models/ +refs/pull/815/*:refs/remotes/origin/pr/815/* # timeout=10 > git rev-parse 2c6405ce9fc2b75491206f235e4ca16685e1e4e2^{commit} # timeout=10 Checking out Revision 2c6405ce9fc2b75491206f235e4ca16685e1e4e2 (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 2c6405ce9fc2b75491206f235e4ca16685e1e4e2 # timeout=10 Commit message: "Merge branch 'main' into fix-logit-scaler" > git rev-list --no-walk 76f7861477466e2003bd4c37851d4b1a93234fd4 # timeout=10 [merlin_models] $ /bin/bash /tmp/jenkins14718599634250335050.sh Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com Requirement already satisfied: testbook in /usr/local/lib/python3.8/dist-packages (0.4.2) Requirement already satisfied: nbformat>=5.0.4 in /usr/local/lib/python3.8/dist-packages (from testbook) (5.5.0) Requirement already satisfied: nbclient>=0.4.0 in /usr/local/lib/python3.8/dist-packages (from testbook) (0.6.8) Requirement already satisfied: fastjsonschema in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (2.16.1) Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.16.0) Requirement already satisfied: jupyter_core in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (4.11.1) Requirement already satisfied: traitlets>=5.1 in /usr/local/lib/python3.8/dist-packages (from nbformat>=5.0.4->testbook) (5.4.0) Requirement already satisfied: jupyter-client>=6.1.5 in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (7.3.5) Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.8/dist-packages (from nbclient>=0.4.0->testbook) (1.5.5) Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (22.1.0) Requirement already satisfied: importlib-resources>=1.4.0; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (5.9.0) Requirement already satisfied: pkgutil-resolve-name>=1.3.10; python_version < "3.9" in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (1.3.10) Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.8/dist-packages (from jsonschema>=2.6->nbformat>=5.0.4->testbook) (0.18.1) Requirement already satisfied: entrypoints in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (0.4) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (2.8.2) Requirement already satisfied: pyzmq>=23.0 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (24.0.0) Requirement already satisfied: tornado>=6.2 in /usr/local/lib/python3.8/dist-packages (from jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (6.2) Requirement already satisfied: zipp>=3.1.0; python_version < "3.10" in /usr/local/lib/python3.8/dist-packages (from importlib-resources>=1.4.0; python_version < "3.9"->jsonschema>=2.6->nbformat>=5.0.4->testbook) (3.8.1) Requirement already satisfied: six>=1.5 in /var/jenkins_home/.local/lib/python3.8/site-packages (from python-dateutil>=2.8.2->jupyter-client>=6.1.5->nbclient>=0.4.0->testbook) (1.15.0) ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.3, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_models/models, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-4.0.0 collected 775 items |
Goals:
As a follow-up to the fixes needed to reproduce the research scripts with the new API (see Add L2-batch regularization to EmbeddingTable #812 and Add LogQ correction support to the new ModelOutput API #811), I defined integration tests for TwoTower and MatrixFactorization. I observed that TwoTower is still returning a smaller validation performance compared to the old API.
After debugging, the issue is now related to
LogitsTemperatureScaler
not including the logic of scaling the logits if the class variableapply_on_call_outputs
is set to True (which is the default).In fact, the logits scaler in ModelOutput is set here with the default value True, so the logits are never scaled with the provided temperature.
apply_on_call_outputs
is specific to Merlin supporting two versions of the API so I believe we should not expose this argument to the user. It might be confusing for the user to understand when to setapply_on_call_outputs
to True or False.So I refactored
LogitsTemperatureScaler
to have similar logic we have inPopularityLogitsCorrection
Figures below show the 3 runs of integration tests: Old API, New API, New API with logit scaler fix
Implementation Details 🚧
copy_with_updates
to the Prediction NamedTuple class to update only necessary tensors in the post transforms blocks.LogitsTemperatureScaler
block.get_config
method toLogitsTemperatureScaler
to store the valuetemperature
when saving the model.Testing Details 🔍