Merge branch 'PriorLabs:main' into issue#138
kgovind0001 authored Feb 10, 2025
2 parents 2a4974b + 634efcd commit 23a08e7
Showing 23 changed files with 1,657 additions and 369 deletions.
92 changes: 92 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
name: Bug Report
description: Create a report to help us reproduce and correct the bug
labels: [bug]

body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't already been
addressed by searching through [the past issues](https://github.com/PriorLabs/TabPFN/issues)
and updating your tabpfn version by running `pip install tabpfn --upgrade`.
- type: textarea
attributes:
label: Describe the bug
description: >
A clear and concise description of what the bug is.
validations:
required: true
- type: textarea
attributes:
label: Steps/Code to Reproduce
description: |
Please add a minimal code example (some advice [here](https://scikit-learn.org/dev/developers/minimal_reproducer.html)) that can reproduce the error when running it. Be as succinct as possible, **do not depend on external data files**: instead you can generate synthetic data using `numpy.random`, [sklearn.datasets.make_regression](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_regression.html), [sklearn.datasets.make_classification](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html) or a few lines of Python code. Example:
```python
import numpy as np
from sklearn.datasets import make_classification
from tabpfn import TabPFNClassifier
# Generate synthetic classification data
X, y = make_classification(n_samples=100, n_features=10, random_state=42)
# Initialize and train TabPFN
classifier = TabPFNClassifier(N_estimators=4)
classifier.fit(X, y)
# Make predictions
y_pred = classifier.predict(X)
```
If the code is too long, feel free to put it in a public gist and link it in the issue: https://gist.github.com.
In short, **we are going to copy-paste your code** to run it and we expect to get the same result as you.
We acknowledge that crafting a minimal reproducible code example requires some effort on your side, but it really helps the maintainers quickly reproduce the problem and analyze its cause without any ambiguity. Ambiguous bug reports tend to be slower to fix because they require more effort and back-and-forth discussion between the maintainers and the reporter to pinpoint the precise conditions necessary to reproduce the problem.
placeholder: |
```
Sample code to reproduce the problem
```
validations:
required: false
- type: textarea
attributes:
label: Expected Results
description: >
Please paste or describe the expected results.
placeholder: >
Example: No error is thrown.
validations:
required: false
- type: textarea
attributes:
label: Actual Results
description: |
Please paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full traceback** of the exception. For instance the code above raises the following exception:
```python-traceback
---------------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/leo/VSCProjects/new/TabPFN/test_git.py", line 9, in <module>
classifier = TabPFNClassifier(N_estimators=4)
TypeError: TabPFNClassifier.__init__() got an unexpected keyword argument 'N_estimators'
```
placeholder: >
Please paste or specifically describe the actual output or traceback.
validations:
required: false
- type: textarea
attributes:
label: Versions
render: shell
description: |
Please run the following and paste the output below.
```python
import tabpfn; tabpfn.display_debug_info()
```
validations:
required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
17 changes: 17 additions & 0 deletions .github/ISSUE_TEMPLATE/doc_improvement.yml
@@ -0,0 +1,17 @@
name: Documentation improvement
description: Create a report to help us improve the documentation. Alternatively you can just open a pull request with the suggested change.
labels: [documentation]

body:
- type: textarea
attributes:
label: Describe the issue linked to the documentation
description: >
Tell us about the confusion introduced in the documentation.
validations:
required: true
- type: textarea
attributes:
label: Suggest a potential alternative/fix
description: >
Tell us how we could improve the documentation in this regard.
30 changes: 30 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,30 @@
name: Feature request
description: Suggest a new feature or enhancement to an existing feature
labels: [enhancement]

body:
- type: textarea
attributes:
label: Describe the workflow you want to enable
validations:
required: true
- type: textarea
attributes:
label: Describe your proposed solution
validations:
required: true
- type: textarea
attributes:
label: Describe alternatives you've considered, if relevant
- type: textarea
attributes:
label: Additional context
- type: dropdown
attributes:
label: Impact
description: How would you rate the impact of this feature for you?
options:
- Critical (Blocking work)
- High (Major improvement)
- Medium (Significant enhancement)
- Low (Nice to have)
15 changes: 15 additions & 0 deletions .github/dependabot.yml
@@ -0,0 +1,15 @@
version: 2
updates:
# Python dependencies
- package-ecosystem: "pip"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 3

# GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 3
125 changes: 125 additions & 0 deletions .github/workflows/pull_request.yml
@@ -0,0 +1,125 @@
name: In pull request
on:
pull_request:
branches:
- main

jobs:
check_python_linting:
name: Ruff Linting & Formatting
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: chartboost/ruff-action@v1
with:
src: "./src ./tests"
version: 0.8.6
- uses: chartboost/ruff-action@v1
with:
src: "./src ./tests"
version: 0.8.6
args: 'format --check'

test_compatibility:
name: Test Package Compatibility
strategy:
fail-fast: false
matrix:
include:
- os: ubuntu-latest
python-version: "3.9"
dependency-set: minimum
- os: macos-13 # macos-latest doesn't work with python 3.10
# https://github.com/actions/setup-python/issues/855
python-version: "3.9"
dependency-set: minimum
- os: ubuntu-latest
python-version: "3.12"
dependency-set: maximum
- os: macos-latest
python-version: "3.12"
dependency-set: maximum
runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true

- name: Generate requirements file for minimum dependencies
if: matrix.dependency-set == 'minimum'
run: |
python << EOF
import re
with open('pyproject.toml', 'r') as f:
content = f.read()
# Find dependencies section using regex
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
if deps_match:
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
min_reqs = []
for dep in deps:
match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep)
if match:
package, min_ver = match.groups()
min_reqs.append(f"{package}=={min_ver}")
with open('requirements.txt', 'w') as f:
f.write('\n'.join(min_reqs))
EOF
- name: Generate requirements file for maximum dependencies
if: matrix.dependency-set == 'maximum'
run: |
python << EOF
import re
with open('pyproject.toml', 'r') as f:
content = f.read()
# Find dependencies section using regex
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
if deps_match:
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
max_reqs = []
for dep in deps:
# Check for maximum version constraint
max_version_match = re.search(r'([^>=<\s]+).*?<\s*([^,\s"\']+)', dep)
if max_version_match:
# If there's a max version, keep it as an exclusive upper bound (package<max_ver)
package, max_ver = max_version_match.groups()
max_reqs.append(f"{package}<{max_ver}")
else:
# If no max version, just use the package name
package = re.match(r'([^>=<\s]+)', dep).group(1)
max_reqs.append(package)
with open('requirements.txt', 'w') as f:
f.write('\n'.join(max_reqs))
EOF
- name: Install dependencies
run: |
uv pip install --system --no-deps .
# onnx is required for onnx export tests
# we don't install all dev dependencies here for speed
uv pip install --system -r requirements.txt
uv pip install --system pytest onnx
- name: Initialize submodules
run: git submodule update --init --recursive

- name: Run Tests
run: |
pytest tests/
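
The two "Generate requirements file" steps above can be tried outside CI with a small standalone sketch. This is only an illustration under stated assumptions: the sample `pyproject.toml` text and the function names `parse_dependencies`, `minimum_requirements`, and `maximum_requirements` are hypothetical and not part of the commit, and the sketch also strips the trailing comma from each entry, which the workflow script does not do. The regular expressions otherwise follow the ones in the workflow.

```python
import re

# Hypothetical sample standing in for the real pyproject.toml contents.
SAMPLE_PYPROJECT = '''
[project]
name = "example"
dependencies = [
  "torch>=2.1,<3",
  "scikit-learn>=1.2.0,<1.7",
  "pandas>=1.4.0",
]
'''

# Same pattern as in the workflow: the non-greedy match stops at the first "]".
DEPS_RE = re.compile(r'dependencies\s*=\s*\[(.*?)\]', re.DOTALL)


def parse_dependencies(content):
    """Return the raw dependency specifiers listed in the dependencies array."""
    match = DEPS_RE.search(content)
    if match is None:
        return []
    # Strip whitespace, quotes and the trailing comma from each entry.
    return [line.strip(' ",\'') for line in match.group(1).splitlines() if line.strip()]


def minimum_requirements(deps):
    """Pin every dependency to its declared lower bound (pkg>=X becomes pkg==X)."""
    pins = []
    for dep in deps:
        m = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s]+)', dep)
        if m:
            pins.append(f"{m.group(1)}=={m.group(2)}")
    return pins


def maximum_requirements(deps):
    """Keep only the upper bound (pkg<Y), or the bare name when none is declared."""
    reqs = []
    for dep in deps:
        m = re.search(r'([^>=<\s]+).*?<\s*([^,\s]+)', dep)
        if m:
            reqs.append(f"{m.group(1)}<{m.group(2)}")
        else:
            reqs.append(re.match(r'([^>=<\s]+)', dep).group(1))
    return reqs


if __name__ == "__main__":
    deps = parse_dependencies(SAMPLE_PYPROJECT)
    print("minimum:", minimum_requirements(deps))
    print("maximum:", maximum_requirements(deps))
```

One caveat of this regex-based approach, in the sketch and in the workflow alike: because the array match ends at the first `]`, a dependency declared with extras (for example `package[extra]>=1.0`) would truncate the list. Parsing the file with a TOML library would avoid that, at the cost of an extra dependency on Python 3.9 and 3.10.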
2 changes: 1 addition & 1 deletion LICENSE.txt → LICENSE
@@ -198,4 +198,4 @@
shall be required.


END OF TERMS AND CONDITIONS
END OF TERMS AND CONDITIONS
36 changes: 14 additions & 22 deletions README.md
@@ -5,6 +5,9 @@
[![Discord](https://img.shields.io/discord/1285598202732482621?color=7289da&label=Discord&logo=discord&logoColor=ffffff)](https://discord.com/channels/1285598202732482621/)
[![Documentation](https://img.shields.io/badge/docs-priorlabs.ai-blue)](https://priorlabs.ai/docs)
[![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://tinyurl.com/tabpfn-colab-local)
[![Python Versions](https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)](https://pypi.org/project/tabpfn/)

<img src="https://github.com/PriorLabs/tabpfn-extensions/blob/main/tabpfn_summary.webp" width="80%" alt="TabPFN Summary">

TabPFN is a foundation model for tabular data that outperforms traditional methods while
being dramatically faster. This repository contains the core PyTorch implementation with
@@ -16,10 +19,6 @@ features. Previous version available at [v1.0.0](../../tree/v1.0.0) and

📚 For detailed usage examples and best practices, check out [Interactive Colab Tutorial](https://tinyurl.com/tabpfn-colab-local)

## 🚀 We're Hiring!

Prior Labs is building breakthrough foundation models for tabular data. [Join our team in Freiburg and Berlin ](https://jobs.ashbyhq.com/prior-labs)

## 🌐 TabPFN Ecosystem

Choose the right TabPFN implementation for your needs:
@@ -112,26 +111,10 @@ We're building the future of tabular machine learning and would love your involv

Prior Labs License (Apache 2.0 with additional attribution requirement): [here](https://priorlabs.ai/tabpfn-license/)

```
Attribution clause (cropped):
If You distribute or make available the Work or any Derivative
Work thereof relating to any part of the source or model weights,
or a product or service (including another AI model) that contains
any source or model weights, You shall (A) provide a copy of this
License with any such materials; and (B) prominently display
“Built with TabPFN” on each related website, user interface, blogpost,
about page, or product documentation. If You use the source or model
weights or model outputs to create, train, fine tune, distil, or
otherwise improve an AI model, which is distributed or made available,
you shall also include “TabPFN” at the beginning of any such AI model name.
To clarify, internal benchmarking and testing without external
communication shall not qualify as distribution or making available
pursuant to this Section 10 and no attribution under this Section 10
shall be required.
```

## 📚 Citation

You can read our paper explaining TabPFN [here](https://doi.org/10.1038/s41586-024-08328-6).

```bibtex
@article{hollmann2025tabpfn,
title={Accurate predictions on small data with a tabular foundation model},
@@ -146,8 +129,17 @@ shall be required.
publisher={Springer Nature},
url={https://www.nature.com/articles/s41586-024-08328-6},
}
@inproceedings{hollmann2023tabpfn,
title={TabPFN: A transformer that solves small tabular classification problems in a second},
author={Hollmann, Noah and M{\"u}ller, Samuel and Eggensperger, Katharina and Hutter, Frank},
booktitle={International Conference on Learning Representations 2023},
year={2023}
}
```



## 🛠️ Development

1. Setup environment: