From a0d7a49a70126e0e81ecb11683b0224399616bf2 Mon Sep 17 00:00:00 2001
From: dnth
Date: Wed, 31 May 2023 03:56:11 +0000
Subject: [PATCH] fix slow loading imagenet

---
 notebooks/imagenet-1k-pytorch.ipynb | 62 ++++++++++++++++++-----------
 vl_datasets/vl_imagenet.py          | 17 +++++---
 2 files changed, 51 insertions(+), 28 deletions(-)

diff --git a/notebooks/imagenet-1k-pytorch.ipynb b/notebooks/imagenet-1k-pytorch.ipynb
index dacf2b0..015fcfb 100644
--- a/notebooks/imagenet-1k-pytorch.ipynb
+++ b/notebooks/imagenet-1k-pytorch.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 1,
    "id": "ea3b9915",
    "metadata": {},
    "outputs": [],
@@ -16,7 +16,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 2,
    "id": "6c727397",
    "metadata": {},
    "outputs": [],
@@ -46,7 +46,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "id": "e3dada78",
    "metadata": {},
    "outputs": [
@@ -54,43 +54,43 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Using provided CSV file: archives/ImageNet-1K_images_issue_file_list.csv\n",
-      "Using provided CSV file: archives/ImageNet-1K_images_issue_file_list.csv\n"
+      "Using provided CSV file: ../../imagenet-1k/archives/ImageNet-1K_images_issue_file_list.csv\n",
+      "Using provided CSV file: ../../imagenet-1k/archives/ImageNet-1K_images_issue_file_list.csv\n"
      ]
     }
    ],
    "source": [
     "from vl_datasets import VLImageNet\n",
-    "train = VLImageNet('./archives', exclude_csv='archives/ImageNet-1K_images_issue_file_list.csv', transform=train_transform)\n",
-    "valid = VLImageNet('./archives', split='val', exclude_csv='archives/ImageNet-1K_images_issue_file_list.csv', transform=valid_transform)"
+    "train_dataset = VLImageNet('../../imagenet-1k/archives', split='train', exclude_csv='../../imagenet-1k/archives/ImageNet-1K_images_issue_file_list.csv', transform=train_transform)\n",
+    "valid_dataset = VLImageNet('../../imagenet-1k/archives', split='val', exclude_csv='../../imagenet-1k/archives/ImageNet-1K_images_issue_file_list.csv', transform=valid_transform)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 4,
    "id": "4fe4e84c",
    "metadata": {},
    "outputs": [],
    "source": [
-    "train_loader = DataLoader(train, batch_size=256, shuffle=True)\n",
-    "valid_loader = DataLoader(valid, batch_size=256, shuffle=True)"
+    "train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
+    "valid_loader = DataLoader(valid_dataset, batch_size=256, shuffle=True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 5,
    "id": "08c1ae84",
    "metadata": {},
    "outputs": [],
    "source": [
     "model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n",
     "num_ftrs = model.fc.in_features\n",
-    "model.fc = nn.Linear(num_ftrs, len(train.classes))"
+    "model.fc = nn.Linear(num_ftrs, len(train_dataset.classes))"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 6,
    "id": "6d2d80da",
    "metadata": {},
    "outputs": [],
@@ -99,9 +99,17 @@
     "optimizer = optim.Adam(model.parameters(), lr=0.001)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "e1150909",
+   "metadata": {},
+   "source": [
+    "For the purpose of demonstration, we will be training only on 1% of the total images. "
" + ] + }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "id": "c7c69d36", "metadata": {}, "outputs": [], @@ -115,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "id": "d88fc413", "metadata": {}, "outputs": [ @@ -124,11 +132,11 @@ "output_type": "stream", "text": [ "Using device: cpu\n", - "Epoch 1 - Loss: 6.17537749959316\n", - "Epoch 2 - Loss: 5.914224514416424\n", - "Epoch 3 - Loss: 5.068374691766483\n", - "Epoch 4 - Loss: 4.328482826895582\n", - "Epoch 5 - Loss: 3.832579866973203\n" + "Epoch 1 - Loss: 6.119085155736327\n", + "Epoch 2 - Loss: 5.854837443867154\n", + "Epoch 3 - Loss: 4.923516098672809\n", + "Epoch 4 - Loss: 4.275343816862248\n", + "Epoch 5 - Loss: 3.8568021098517646\n" ] } ], @@ -140,7 +148,7 @@ "\n", "for epoch in range(num_epochs):\n", " running_loss = 0.0\n", - " for i, data in enumerate(subset_loader):\n", + " for i, data in enumerate(subset_loader): # to train on the full dataset replace subset_loader with train_loader\n", " inputs, labels = data\n", " inputs, labels = inputs.to(device), labels.to(device)\n", "\n", @@ -158,10 +166,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "b19d0ba6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy: 0.19878699365098684\n" + ] + } + ], "source": [ "correct = 0\n", "total = 0\n", diff --git a/vl_datasets/vl_imagenet.py b/vl_datasets/vl_imagenet.py index ac72203..b6cbcce 100644 --- a/vl_datasets/vl_imagenet.py +++ b/vl_datasets/vl_imagenet.py @@ -1,7 +1,7 @@ # Code adapted from https://github.com/pytorch/vision/blob/main/torchvision/datasets/ from torchvision.datasets import ImageNet -from typing import Callable, Optional, Union, Sequence, Any, Tuple +from typing import Optional, Any, Tuple import pandas as pd class VLImageNet(ImageNet): @@ -16,11 +16,18 @@ def __init__(self, root: str, self.exclude_df, self.exclude_set = parse_exclude_csv(exclude_csv) # Filter file lists based on VL CSV files - # TODO: use more efficient method. This takes too long. Sets subtraction maybe? - image_keep_list = [i for i, (filename, class_num) in enumerate(self.samples) if not filename.endswith(tuple(self.exclude_set))] + # Extract filenames from samples + filenames = {sample[0].split("/")[-1] for sample in self.samples} - self.samples = [self.samples[i] for i in image_keep_list] - self.targets = [self.targets[i] for i in image_keep_list] + # Remove filenames found in exclude_set + filtered_filenames = filenames - self.exclude_set + + # Create the filtered_list by filtering tuples_list based on the filtered_filenames + filtered_samples = [(filename, label) for filename, label in self.samples if filename.split("/")[-1] in filtered_filenames] + filtered_targets = [s[1] for s in filtered_samples] + + self.samples = filtered_samples + self.targets = filtered_targets def parse_exclude_csv(exclude_csv_arg: str) -> Tuple[pd.DataFrame, set]: