-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'PriorLabs:main' into issue#138
- Loading branch information
Showing
23 changed files
with
1,657 additions
and
369 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
name: Bug Report | ||
description: Create a report to help us reproduce and correct the bug | ||
labels: [bug] | ||
|
||
body: | ||
- type: markdown | ||
attributes: | ||
value: > | ||
#### Before submitting a bug, please make sure the issue hasn't been already | ||
addressed by searching through [the past issues](https://github.com/PriorLabs/TabPFN/issues) | ||
and updating your tabpfn version by running `pip install tabpfn --upgrade`. | ||
- type: textarea | ||
attributes: | ||
label: Describe the bug | ||
description: > | ||
A clear and concise description of what the bug is. | ||
validations: | ||
required: true | ||
- type: textarea | ||
attributes: | ||
label: Steps/Code to Reproduce | ||
description: | | ||
Please add a minimal code example (some advices [here](https://scikit-learn.org/dev/developers/minimal_reproducer.html)) that can reproduce the error when running it. Be as succinct as possible, **do not depend on external data files**: instead you can generate synthetic data using `numpy.random`, [sklearn.datasets.make_regression](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html), [sklearn.datasets.make_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) or a few lines of Python code. Example: | ||
```python | ||
import numpy as np | ||
from sklearn.datasets import make_classification | ||
from tabpfn import TabPFNClassifier | ||
# Generate synthetic classification data | ||
X, y = make_classification(n_samples=100, n_features=10, random_state=42) | ||
# Initialize and train TabPFN | ||
classifier = TabPFNClassifier(N_estimators=4) | ||
classifier.fit(X, y) | ||
# Make predictions | ||
y_pred = classifier.predict(X) | ||
``` | ||
If the code is too long, feel free to put it in a public gist and link it in the issue: https://gist.github.com. | ||
In short, **we are going to copy-paste your code** to run it and we expect to get the same result as you. | ||
We acknowledge that crafting a minimal reproducible code example requires some effort on your side but it really helps the maintainers quickly reproduce the problem and analyze its cause without any ambiguity. Ambiguous bug reports tend to be slower to fix because they will require more effort and back and forth discussion between the maintainers and the reporter to pin-point the precise conditions necessary to reproduce the problem. | ||
placeholder: | | ||
``` | ||
Sample code to reproduce the problem | ||
``` | ||
validations: | ||
required: false | ||
- type: textarea | ||
attributes: | ||
label: Expected Results | ||
description: > | ||
Please paste or describe the expected results. | ||
placeholder: > | ||
Example: No error is thrown. | ||
validations: | ||
required: false | ||
- type: textarea | ||
attributes: | ||
label: Actual Results | ||
description: | | ||
Please paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full traceback** of the exception. For instance the code above raises the following exception: | ||
```python-traceback | ||
--------------------------------------------------------------------------- | ||
Traceback (most recent call last): | ||
File "/Users/leo/VSCProjects/new/TabPFN/test_git.py", line 9, in <module> | ||
classifier = TabPFNClassifier(N_estimators=4) | ||
TypeError: TabPFNClassifier.__init__() got an unexpected keyword argument 'N_estimators' | ||
``` | ||
placeholder: > | ||
Please paste or specifically describe the actual output or traceback. | ||
validations: | ||
required: false | ||
- type: textarea | ||
attributes: | ||
label: Versions | ||
render: shell | ||
description: | | ||
Please run the following and paste the output below. | ||
```python | ||
import tabpfn; tabpfn.display_debug_info() | ||
``` | ||
validations: | ||
required: false | ||
- type: markdown | ||
attributes: | ||
value: > | ||
Thanks for contributing 🎉! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
name: Documentation improvement | ||
description: Create a report to help us improve the documentation. Alternatively you can just open a pull request with the suggested change. | ||
labels: [documentation] | ||
|
||
body: | ||
- type: textarea | ||
attributes: | ||
label: Describe the issue linked to the documentation | ||
description: > | ||
Tell us about the confusion introduced in the documentation. | ||
validations: | ||
required: true | ||
- type: textarea | ||
attributes: | ||
label: Suggest a potential alternative/fix | ||
description: > | ||
Tell us how we could improve the documentation in this regard. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
name: Feature request | ||
description: Suggest a new feature or enhancement to an existing feature | ||
labels: [enhancement] | ||
|
||
body: | ||
- type: textarea | ||
attributes: | ||
label: Describe the workflow you want to enable | ||
validations: | ||
required: true | ||
- type: textarea | ||
attributes: | ||
label: Describe your proposed solution | ||
validations: | ||
required: true | ||
- type: textarea | ||
attributes: | ||
label: Describe alternatives you've considered, if relevant | ||
- type: textarea | ||
attributes: | ||
label: Additional context | ||
- type: dropdown | ||
attributes: | ||
label: Impact | ||
description: How would you rate the impact of this feature for you? | ||
options: | ||
- Critical (Blocking work) | ||
- High (Major improvement) | ||
- Medium (Significant enhancement) | ||
- Low (Nice to have) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
version: 2 | ||
updates: | ||
# Python dependencies | ||
- package-ecosystem: "pip" | ||
directory: "/" | ||
schedule: | ||
interval: "weekly" | ||
open-pull-requests-limit: 3 | ||
|
||
# GitHub Actions | ||
- package-ecosystem: "github-actions" | ||
directory: "/" | ||
schedule: | ||
interval: "weekly" | ||
open-pull-requests-limit: 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
name: In pull request | ||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
check_python_linting: | ||
name: Ruff Linting & Formatting | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: chartboost/ruff-action@v1 | ||
with: | ||
src: "./src ./tests" | ||
version: 0.8.6 | ||
- uses: chartboost/ruff-action@v1 | ||
with: | ||
src: "./src ./tests" | ||
version: 0.8.6 | ||
args: 'format --check' | ||
|
||
test_compatibility: | ||
name: Test Package Compatibility | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
include: | ||
- os: ubuntu-latest | ||
python-version: "3.9" | ||
dependency-set: minimum | ||
- os: macos-13 # macos-latest doesn't work with python 3.10 | ||
# https://github.com/actions/setup-python/issues/855 | ||
python-version: "3.9" | ||
dependency-set: minimum | ||
- os: ubuntu-latest | ||
python-version: "3.12" | ||
dependency-set: maximum | ||
- os: macos-latest | ||
python-version: "3.12" | ||
dependency-set: maximum | ||
runs-on: ${{ matrix.os }} | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
architecture: x64 | ||
|
||
- name: Install uv | ||
uses: astral-sh/setup-uv@v5 | ||
with: | ||
enable-cache: true | ||
|
||
- name: Generate requirements file for minimum dependencies | ||
if: matrix.dependency-set == 'minimum' | ||
run: | | ||
python << EOF | ||
import re | ||
with open('pyproject.toml', 'r') as f: | ||
content = f.read() | ||
# Find dependencies section using regex | ||
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL) | ||
if deps_match: | ||
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()] | ||
min_reqs = [] | ||
for dep in deps: | ||
match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep) | ||
if match: | ||
package, min_ver = match.groups() | ||
min_reqs.append(f"{package}=={min_ver}") | ||
with open('requirements.txt', 'w') as f: | ||
f.write('\n'.join(min_reqs)) | ||
EOF | ||
- name: Generate requirements file for maximum dependencies | ||
if: matrix.dependency-set == 'maximum' | ||
run: | | ||
python << EOF | ||
import re | ||
with open('pyproject.toml', 'r') as f: | ||
content = f.read() | ||
# Find dependencies section using regex | ||
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL) | ||
if deps_match: | ||
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()] | ||
max_reqs = [] | ||
for dep in deps: | ||
# Check for maximum version constraint | ||
max_version_match = re.search(r'([^>=<\s]+).*?<\s*([^,\s"\']+)', dep) | ||
if max_version_match: | ||
# If there's a max version, use the version just below it | ||
package, max_ver = max_version_match.groups() | ||
max_reqs.append(f"{package}<{max_ver}") | ||
else: | ||
# If no max version, just use the package name | ||
package = re.match(r'([^>=<\s]+)', dep).group(1) | ||
max_reqs.append(package) | ||
with open('requirements.txt', 'w') as f: | ||
f.write('\n'.join(max_reqs)) | ||
EOF | ||
- name: Install dependencies | ||
run: | | ||
uv pip install --system --no-deps . | ||
# onnx is required for onnx export tests | ||
# we don't install all dev dependencies here for speed | ||
uv pip install --system -r requirements.txt | ||
uv pip install --system pytest onnx | ||
- name: Initialize submodules | ||
run: git submodule update --init --recursive | ||
|
||
- name: Run Tests | ||
run: | | ||
pytest tests/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -198,4 +198,4 @@ | |
shall be required. | ||
|
||
|
||
END OF TERMS AND CONDITIONS | ||
END OF TERMS AND CONDITIONS |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.