This repository contains code for classifying brain tumor types in MRI images using ResNet50, a Convolutional Neural Network (CNN) architecture. It also includes a web service built with FastAPI for real-time inference.
The model is trained and tested on a combination of the following two Kaggle datasets:
- https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
- https://www.kaggle.com/datasets/sartajbhuvaji/brain-tumor-classification-mri
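The README does not describe how the two datasets are combined. The following is a minimal sketch under these assumptions: each archive is unpacked into its own folder containing Training/ and Testing/ sub-folders with one directory per class, and the class folders are named glioma, meningioma, notumor and pituitary in the first dataset and glioma_tumor, meningioma_tumor, no_tumor and pituitary_tumor in the second. The function merge_datasets and the LABEL_MAP mapping are hypothetical and are not taken from this repository.

```python
import shutil
from pathlib import Path

# Assumed mapping from each dataset's class-folder names to one shared label set.
LABEL_MAP = {
    "glioma": "glioma", "glioma_tumor": "glioma",
    "meningioma": "meningioma", "meningioma_tumor": "meningioma",
    "notumor": "notumor", "no_tumor": "notumor",
    "pituitary": "pituitary", "pituitary_tumor": "pituitary",
}


def merge_datasets(sources, destination):
    """Copy images from several dataset roots into destination/<split>/<label>/.

    Each source root is assumed to contain Training/ and Testing/ folders with
    one sub-folder per class. Returns the number of files copied.
    """
    destination = Path(destination)
    copied = 0
    for index, source in enumerate(sources):
        for split in ("Training", "Testing"):
            split_dir = Path(source) / split
            if not split_dir.is_dir():
                continue  # this dataset has no folder for this split
            for class_dir in sorted(split_dir.iterdir()):
                label = LABEL_MAP.get(class_dir.name)
                if label is None or not class_dir.is_dir():
                    continue  # skip unknown folders and stray files
                target_dir = destination / split / label
                target_dir.mkdir(parents=True, exist_ok=True)
                for image in sorted(class_dir.iterdir()):
                    if image.is_file():
                        # Prefix with the source index so equally named files
                        # from different datasets do not overwrite each other.
                        shutil.copy2(image, target_dir / f"{index}_{image.name}")
                        copied += 1
    return copied
```

For example, merge_datasets(["data/brain-tumor-mri-dataset", "data/brain-tumor-classification-mri"], "data/merged"), where the paths are placeholders. The prefix only prevents name collisions; the sketch does not detect identical images that appear in both datasets.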
Requirements:
- Python 3.8 or higher
- pip
- Virtual Environment (recommended)
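To confirm that the interpreter meets the Python 3.8 requirement before installing anything, a small check like the following can be run. The function name python_is_supported is illustrative and not part of the repository.

```python
import sys

MIN_VERSION = (3, 8)  # minimum version listed in the requirements


def python_is_supported(version_info=sys.version_info, minimum=MIN_VERSION):
    """Return True if the (major, minor) part of version_info meets the minimum."""
    return tuple(version_info[:2]) >= minimum


if __name__ == "__main__":
    current = ".".join(map(str, sys.version_info[:3]))
    if python_is_supported():
        print(f"Python {current} is supported.")
    else:
        sys.exit(f"Python {current} found; 3.8 or higher is required.")
```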
- Clone the repository: clone this repository to a folder of your choice.
git clone https://github.com/adisve/brain-tumor-classifier.git
- Navigate to the project folder: move into the cloned project directory.
cd brain-tumor-classifier
- Create a virtual environment (recommended): a virtual environment isolates the project's package dependencies from other Python installations. To create one, run:
python3 -m venv .venv
To activate the virtual environment, run the command for your platform:
- Linux/macOS:
source .venv/bin/activate
- Windows:
.venv\Scripts\activate
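After activation, you can verify that the active interpreter belongs to the virtual environment. Python's venv module sets sys.prefix to the environment directory while sys.base_prefix keeps pointing at the base installation, so the two differ inside an activated venv. The function name in_virtualenv is illustrative.

```python
import sys


def in_virtualenv():
    """Return True when running inside a venv-created virtual environment."""
    # Outside a venv, sys.prefix and sys.base_prefix are the same directory.
    return sys.prefix != getattr(sys, "base_prefix", sys.prefix)


if __name__ == "__main__":
    print("virtual environment active:", in_virtualenv())
    print("interpreter:", sys.executable)
```

If this prints False after activation, the shell is most likely still using the system interpreter.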
- Install the dependencies: from the project root, run:
pip install .
This installs the project together with the dependencies declared in its pyproject.toml file.
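To check which packages ended up installed, the standard library's importlib.metadata (available since Python 3.8) can report installed versions. The PACKAGES list below is a hypothetical example based on the tools this README mentions; the authoritative dependency list is the one in pyproject.toml.

```python
from importlib import metadata

# Hypothetical list: replace with the dependencies declared in pyproject.toml.
PACKAGES = ["fastapi", "uvicorn", "tensorflow"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name:<12} {version or 'NOT INSTALLED'}")
```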
The web service is built with FastAPI and serves real-time predictions from the trained model.
To run the web service on your local machine, navigate to the server/ directory and execute:
uvicorn api:app --reload
This starts the FastAPI server at http://127.0.0.1:8000, uvicorn's default address. The --reload flag restarts the server whenever the source code changes and is intended for development.
Endpoints:
- Predict: POST /predict/
  - Accepts an MRI image and returns the predicted brain tumor type.
For the full request and response schemas, see the auto-generated FastAPI documentation at http://127.0.0.1:8000/docs.
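A client can call the endpoint with an HTTP multipart upload. The sketch below uses only the standard library and assumes the endpoint reads the image from a form field named file and returns JSON; both the field name and the response shape are assumptions, so check them against the /docs page. The names build_multipart and predict are hypothetical helpers.

```python
import json
import mimetypes
import urllib.request
import uuid
from pathlib import Path

API_URL = "http://127.0.0.1:8000/predict/"
FIELD_NAME = "file"  # assumption: the form field name the endpoint expects


def build_multipart(field_name, filename, content, content_type):
    """Encode one file as a multipart/form-data body; return (body, content_type_header)."""
    boundary = uuid.uuid4().hex
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field_name}"; filename="{filename}"\r\n'
        f"Content-Type: {content_type}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return head + content + tail, f"multipart/form-data; boundary={boundary}"


def predict(image_path, url=API_URL, field_name=FIELD_NAME):
    """POST an MRI image to the prediction endpoint and return the decoded JSON."""
    path = Path(image_path)
    content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
    body, header = build_multipart(field_name, path.name, path.read_bytes(), content_type)
    request = urllib.request.Request(
        url, data=body, method="POST", headers={"Content-Type": header}
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, print(predict("sample_mri.jpg")) sends a local image (the file name is a placeholder) and prints whatever JSON the service returns.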
To train the model locally, navigate to the model/ directory and open the Jupyter notebook.
The model uses the following Keras callbacks during training:
- EarlyStopping: stops training when the validation loss stops improving.
- ReduceLROnPlateau: reduces the learning rate when a monitored metric has stopped improving.
- ModelCheckpoint: saves the model at the end of every epoch.
- LambdaCallback: a custom callback for additional functionality (here, displaying the confusion matrix).
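Keras's LambdaCallback accepts an on_epoch_end function that is called with (epoch, logs). The following is a minimal sketch of such a function that prints a confusion matrix. It is not the notebook's implementation: predict_fn is a hypothetical callable that returns one list of class scores per input (for example, a wrapper around model.predict), and y_val is assumed to hold integer class labels. The matrix itself is tallied in plain Python so the logic is easy to follow.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """Return a num_classes x num_classes matrix; rows are true labels, columns predictions."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        matrix[true_label][pred_label] += 1
    return matrix


def argmax(scores):
    """Index of the largest score (the first one on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def make_confusion_matrix_hook(predict_fn, x_val, y_val, class_names):
    """Build an on_epoch_end(epoch, logs) function that prints a confusion matrix."""
    def on_epoch_end(epoch, logs=None):
        predictions = [argmax(scores) for scores in predict_fn(x_val)]
        matrix = confusion_matrix(y_val, predictions, len(class_names))
        width = max(len(name) for name in class_names)
        print(f"Confusion matrix after epoch {epoch + 1}:")
        for name, row in zip(class_names, matrix):
            print(f"{name:>{width}} " + " ".join(f"{count:5d}" for count in row))
        return matrix  # LambdaCallback ignores this; returned for inspection
    return on_epoch_end

# With TensorFlow installed, the hook could be attached like this (not run here):
# tf.keras.callbacks.LambdaCallback(on_epoch_end=make_confusion_matrix_hook(...))
```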
Train and evaluate the model by running all cells in the notebook. At the end, it displays metrics such as loss, accuracy, and AUC (Area Under the Curve), along with plots of the results.
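To make the reported metrics concrete, here is a plain-Python sketch of accuracy and ROC AUC. For multiple classes it averages one-vs-rest AUCs (a macro average). This is an illustration only: the notebook's numbers may be computed differently. Keras's AUC metric, for example, approximates the curve with a fixed set of thresholds and by default flattens multi-class outputs into a single binary problem. The pairwise comparison is quadratic in the number of examples, which is fine for small checks.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_auc(labels, scores):
    """ROC AUC: probability that a random positive outscores a random negative (ties count 0.5)."""
    positives = [s for s, l in zip(scores, labels) if l]
    negatives = [s for s, l in zip(scores, labels) if not l]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(positives) * len(negatives))


def macro_ovr_auc(y_true, probabilities, num_classes):
    """Average of one-vs-rest ROC AUCs over all classes."""
    aucs = []
    for c in range(num_classes):
        labels = [label == c for label in y_true]  # class c vs. all others
        scores = [row[c] for row in probabilities]
        aucs.append(binary_auc(labels, scores))
    return sum(aucs) / num_classes
```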