Categorical Embeddings for Tabular Data with PyTorch

This repository contains a PyTorch implementation of a deep learning approach to tabular data, focusing on categorical embeddings for multiclass classification. The methodology is applied to the Shelter Animal Outcomes dataset to predict the outcome for each animal in a shelter. Feature importance is analyzed with Mean Decrease in Impurity (MDI) and permutation importance to better interpret the model's decisions.

Abstract

While deep learning has shown remarkable success in fields like computer vision and natural language processing, its application to tabular data is less explored, even though tabular data is the most common form of data in industry and presents its own challenges. This project demonstrates the effectiveness of representing categorical variables as embeddings in a continuous vector space, using PyTorch for its GPU acceleration and prebuilt functionality. The model predicts the probability of each of five outcome categories for an animal in a shelter and assesses the importance of different features in making these predictions.
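
As a minimal illustration of the core idea (the column name and sizes below are hypothetical, not taken from this repository), PyTorch's `nn.Embedding` maps integer-encoded categories to trainable dense vectors:

```python
import torch
import torch.nn as nn

# Hypothetical "breed" column with 1,500 distinct categories, each mapped
# to an 8-dimensional trainable vector.
breed_embedding = nn.Embedding(num_embeddings=1500, embedding_dim=8)

# A mini-batch of three integer-encoded breed labels.
breed_ids = torch.tensor([3, 42, 7])
breed_vectors = breed_embedding(breed_ids)
print(breed_vectors.shape)  # torch.Size([3, 8])
```

These vectors are learned jointly with the rest of the network, so categories that behave similarly end up close together in the embedding space.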

Dataset

The Shelter Animal Outcomes dataset from Kaggle is used, detailing various aspects of animals in the Austin Animal Center. Information includes breed, color, sex, age, and more, mapped against the outcome for each animal.
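
A minimal sketch of one way to load and integer-encode the data with pandas and scikit-learn, assuming the Kaggle training file has been downloaded as `train.csv` (the repository's actual preprocessing may differ):

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Assumes the Kaggle "Shelter Animal Outcomes" training file is present locally.
df = pd.read_csv("train.csv")

# Integer-encode the categorical feature columns (names follow the
# competition's data description).
categorical_cols = ["AnimalType", "SexuponOutcome", "Breed", "Color"]
for col in categorical_cols:
    df[col] = LabelEncoder().fit_transform(df[col].astype(str))

# Encode the five outcome classes as the classification target.
df["OutcomeType"] = LabelEncoder().fit_transform(df["OutcomeType"])
```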

Getting Started

Prerequisites

  • Python 3.6+
  • PyTorch 1.x
  • scikit-learn
  • pandas
  • numpy

Installation

Clone the repository to your local machine:

```bash
git clone https://github.com/your_github_username/categorical_embeddings_pytorch.git
cd categorical_embeddings_pytorch
```

Install the required dependencies:

```bash
pip install -r requirements.txt
```

Model and Training

A feedforward neural network is used, with separate pathways for numerical and categorical data. Categorical features are mapped to embeddings before being concatenated with the numerical features for prediction. Training is driven by the Adam optimizer and the cross-entropy loss.
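
The sketch below shows one way such a model can be wired together in PyTorch; the column cardinalities, embedding sizes, and hidden-layer width are illustrative assumptions rather than the exact configuration used in this repository:

```python
import torch
import torch.nn as nn

class TabularNet(nn.Module):
    """Illustrative embedding + feedforward model; sizes are assumptions,
    not the exact architecture used in this repository."""

    def __init__(self, cardinalities, emb_dims, n_numeric, n_classes=5):
        super().__init__()
        # One embedding table per categorical column.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(card, dim) for card, dim in zip(cardinalities, emb_dims)]
        )
        self.net = nn.Sequential(
            nn.Linear(sum(emb_dims) + n_numeric, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes),
        )

    def forward(self, x_cat, x_num):
        # Embed each categorical column, then concatenate with the numeric features.
        embs = [emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)]
        return self.net(torch.cat(embs + [x_num], dim=1))

# Assumed cardinalities for four categorical columns plus one numeric column.
cardinalities = [2, 5, 1500, 400]
model = TabularNet(cardinalities, emb_dims=[2, 3, 50, 20], n_numeric=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# One illustrative training step on a random mini-batch of 32 rows.
x_cat = torch.stack([torch.randint(0, c, (32,)) for c in cardinalities], dim=1)
x_num = torch.randn(32, 1)
y = torch.randint(0, 5, (32,))

optimizer.zero_grad()
loss = criterion(model(x_cat, x_num), y)
loss.backward()
optimizer.step()
```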

Feature Importance

Two methods are applied to assess feature importance: Random Forest feature importance (Mean Decrease in Impurity) and permutation feature importance. These analyses help in understanding the impact of different features on the model's predictions.
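
A sketch of how both measures can be computed with scikit-learn; the random data below stands in for the encoded feature matrix and labels, and the estimator settings are assumptions:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Random stand-in data; in practice X is the encoded feature matrix and
# y the encoded OutcomeType labels.
rng = np.random.default_rng(0)
X = rng.integers(0, 10, size=(1000, 5)).astype(float)
y = rng.integers(0, 5, size=1000)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

forest = RandomForestClassifier(n_estimators=200, random_state=0)
forest.fit(X_train, y_train)

# Mean Decrease in Impurity (MDI): average impurity reduction per feature.
print("MDI importances:", forest.feature_importances_)

# Permutation importance: drop in validation score when a feature is shuffled.
perm = permutation_importance(forest, X_val, y_val, n_repeats=10, random_state=0)
print("Permutation importances:", perm.importances_mean)
```

Permutation importance is evaluated on held-out data, so it is less prone than MDI to overstating the importance of high-cardinality features.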

Results

The model predicts outcome probabilities for shelter animals accurately. The feature importance analysis shows which features most strongly influence the model's predictions, aiding interpretability and guiding model improvement.

References

Shelter Animal Outcomes Dataset: https://www.kaggle.com/c/shelter-animal-outcomes/data

License

This project is licensed under the MIT License - see the LICENSE.md file for details.

Acknowledgments

  • Austin Animal Center for providing the dataset.
  • The PyTorch team for the comprehensive deep learning library.
  • Contributors to the scikit-learn library, which was used for the feature importance analysis.

Remember to replace `https://github.com/your_github_username/categorical_embeddings_pytorch.git` with the actual URL of your GitHub repository and adjust any paths or script names to match your project's structure.