feat: fashion MNIST example (#56)
timurcarstensen authored Aug 20, 2024
1 parent 66d43b1 commit d7d5aae
Showing 5 changed files with 373 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/run-examples.yml
@@ -28,3 +28,7 @@ jobs:
        run: |
          python -m examples.sinc.sinc_nas
          python -m examples.sinc.search
      - name: Run fashion MNIST example
        run: |
          python -m examples.fashion_mnist.train_fashion_mnist
          python -m examples.fashion_mnist.search_fashion_mnist
Empty file.
49 changes: 49 additions & 0 deletions examples/fashion_mnist/model.py
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from whittle.modules import Linear


class LeNet(nn.Module):
    def __init__(self, fc1_out: int, fc2_out: int, fc_base_out: int | None = 128):
        super().__init__()
        # 28x28 grayscale inputs are flattened and passed through a stack of
        # fully-connected whittle Linear layers; fc1 and fc2 are the prunable widths.
        self.fc_base_out = fc_base_out
        self.fc_base = Linear(28 * 28, fc_base_out, bias=True)
        self.fc1 = Linear(fc_base_out, fc1_out, bias=True)
        self.fc2 = Linear(fc1_out, fc2_out, bias=True)
        self.fc3 = Linear(fc2_out, 10, bias=True)

    def forward(self, x: torch.Tensor):
        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.fc_base(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def select_sub_network(self, config: dict[Literal["fc1_out", "fc2_out"], int]):
        fc1_out = config["fc1_out"]
        fc2_out = config["fc2_out"]

        self.fc1.set_sub_network(
            sub_network_in_features=self.fc_base_out, sub_network_out_features=fc1_out
        )
        self.fc2.set_sub_network(
            sub_network_in_features=fc1_out,
            sub_network_out_features=fc2_out,
        )
        self.fc3.set_sub_network(
            sub_network_in_features=fc2_out, sub_network_out_features=10
        )

    def reset_super_network(self):
        self.fc_base.reset_super_network()
        self.fc1.reset_super_network()
        self.fc2.reset_super_network()
        self.fc3.reset_super_network()
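
For reference, a minimal usage sketch of the model above (not part of the commit): it instantiates the super-network, activates a smaller sub-network via select_sub_network, runs a forward pass on a dummy batch, and then restores the full network. The config keys match the search space used in the scripts below; the dummy tensor shape is chosen purely for illustration.

import torch

from examples.fashion_mnist.model import LeNet

# Build the super-network with the default widths used by the scripts below.
model = LeNet(fc1_out=32, fc2_out=16)

# Activate a narrower sub-network; the keys mirror the search space.
model.select_sub_network(config={"fc1_out": 8, "fc2_out": 4})

# Forward pass on a dummy batch of four 28x28 grayscale images (illustrative shape).
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])

# Restore the full super-network afterwards.
model.reset_super_network()
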
110 changes: 110 additions & 0 deletions examples/fashion_mnist/search_fashion_mnist.py
@@ -0,0 +1,110 @@
from __future__ import annotations

from argparse import ArgumentParser
from pathlib import Path
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import torch
from syne_tune.config_space import randint
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

from examples.fashion_mnist.model import LeNet
from examples.fashion_mnist.train_fashion_mnist import validate
from whittle.search import multi_objective_search


def compute_mac_linear_layer(in_features: int, out_features: int):
    return in_features * out_features


def objective(
    config: dict[Literal["fc1_out", "fc2_out"], int], model: LeNet, device: torch.device
) -> tuple[int, float]:
    model.select_sub_network(config=config)

    # `test_loader` is defined at module scope in the __main__ block below.
    _, loss = validate(
        test_loader=test_loader,
        model=model,
        criterion=torch.nn.functional.cross_entropy,
        device=device,
    )

    # Count multiply-accumulate operations of the selected sub-network.
    mac = 0
    mac += compute_mac_linear_layer(
        model.fc_base.in_features, model.fc_base.out_features
    )
    mac += compute_mac_linear_layer(model.fc1.in_features, config["fc1_out"])
    mac += compute_mac_linear_layer(config["fc1_out"], config["fc2_out"])
    mac += compute_mac_linear_layer(config["fc2_out"], 10)

    return mac, loss


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--fc1_out", type=int, default=32)
    parser.add_argument("--fc2_out", type=int, default=16)
    parser.add_argument("--training_strategy", type=str, default="sandwich")
    parser.add_argument("--search_strategy", type=str, default="random_search")
    # NOTE: argparse's type=bool treats any non-empty string as True.
    parser.add_argument("--do_plot", type=bool, default=True)
    parser.add_argument("--st_checkpoint_dir", type=str, default="./checkpoints")
    args, _ = parser.parse_known_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    test_loader = DataLoader(
        dataset=datasets.FashionMNIST(".", train=False, transform=ToTensor()),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model = LeNet(fc1_out=args.fc1_out, fc2_out=args.fc2_out)
    path = (
        Path(args.st_checkpoint_dir)
        / f"{args.training_strategy}_model_{args.fc1_out}_{args.fc2_out}.pt"
    )
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["state"])
    model = model.to(device)
    model.eval()

    search_space = {
        "fc1_out": randint(1, args.fc1_out),
        "fc2_out": randint(1, args.fc2_out),
    }

    results = multi_objective_search(
        objective,
        search_space,
        objective_kwargs={"model": model, "device": device},
        search_strategy=args.search_strategy,
        num_samples=20,
        seed=42,
    )

    costs = np.array(results["costs"])
    idx = np.array(results["is_pareto_optimal"])

    if args.do_plot:
        plt.scatter(costs[:, 0], costs[:, 1], color="black", label="sub-networks", s=30)
        plt.scatter(
            costs[idx, 0],
            costs[idx, 1],
            color="red",
            label="Pareto optimal",
            s=50,
        )

        plt.xlabel("MACs")
        plt.ylabel("Validation Loss")
        plt.xscale("log")
        plt.grid(linewidth=1, alpha=0.4)
        plt.title("Pareto front for Fashion MNIST")
        plt.legend()
        plt.tight_layout()
        plt.savefig("pareto_front.png", dpi=300)
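
As context for the is_pareto_optimal mask consumed above, here is a small self-contained numpy sketch (not whittle's implementation) of how such a mask over (MACs, loss) pairs can be computed, assuming both objectives are minimized:

import numpy as np


def pareto_mask(costs: np.ndarray) -> np.ndarray:
    """Boolean mask of non-dominated points; both columns are minimized."""
    n = costs.shape[0]
    mask = np.ones(n, dtype=bool)
    for i in range(n):
        # A point is dominated if another point is <= in all objectives
        # and strictly < in at least one.
        dominated = np.all(costs <= costs[i], axis=1) & np.any(costs < costs[i], axis=1)
        if dominated.any():
            mask[i] = False
    return mask


costs = np.array([[1e5, 0.9], [5e4, 1.2], [2e5, 0.7], [1e5, 1.5]])
print(pareto_mask(costs))  # [ True  True  True False]
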
210 changes: 210 additions & 0 deletions examples/fashion_mnist/train_fashion_mnist.py
@@ -0,0 +1,210 @@
from __future__ import annotations

import json
import os
from argparse import ArgumentParser
from pathlib import Path
from typing import Callable

from tqdm import trange
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from syne_tune.config_space import randint
from syne_tune.report import Reporter
from tabulate import tabulate
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor

from examples.fashion_mnist.model import LeNet
from whittle.sampling.random_sampler import RandomSampler
from whittle.training_strategies import (
RandomStrategy,
SandwichStrategy,
StandardStrategy,
)


def correct(output: torch.Tensor, target: torch.Tensor) -> int:
    """Returns the number of correct predictions."""
    predicted_digits = output.argmax(1)
    correct_ones = (predicted_digits == target).type(torch.float32)
    return int(correct_ones.sum().item())


def validate(
    model: LeNet, test_loader: DataLoader, criterion: Callable, device: torch.device
) -> tuple[float, float]:
    model.eval()

    num_batches = len(test_loader)
    num_items = len(test_loader.dataset)

    test_loss = 0
    total_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # Copy data and targets to the device
            data = data.to(device)
            target = target.to(device)

            output = model.forward(x=data)
            # Calculate the loss
            loss = criterion(output, target)
            test_loss += loss.item()

            # Count the number of correct predictions
            total_correct += correct(output, target)

    # Average the loss over all batches; report accuracy as a percentage.
    test_loss = test_loss / num_batches
    accuracy = total_correct / num_items

    # print(f"Testset accuracy: {100 * accuracy:>0.1f}%, average loss: {test_loss:>7f}")
    return 100 * accuracy, test_loss


if __name__ == "__main__":
    report = Reporter()

    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--fc1_out", type=int, default=32)
    parser.add_argument("--fc2_out", type=int, default=16)
    parser.add_argument("--training_strategy", type=str, default="sandwich")
    parser.add_argument("--st_checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--num_train_samples", type=int, default=20000)
    parser.add_argument("--seed", type=int, default=42)
    args, _ = parser.parse_known_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_dataset = datasets.FashionMNIST(
        ".", train=True, download=True, transform=ToTensor()
    )

    # Split the dataset
    train_dataset, _ = random_split(
        full_dataset,
        [args.num_train_samples, len(full_dataset) - args.num_train_samples],
    )

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        dataset=datasets.FashionMNIST(".", train=False, transform=ToTensor()),
        batch_size=args.batch_size,
        shuffle=False,
    )

    model = LeNet(fc1_out=args.fc1_out, fc2_out=args.fc2_out).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    search_space = {
        "fc1_out": randint(1, args.fc1_out),
        "fc2_out": randint(1, args.fc2_out),
    }

    os.makedirs(args.st_checkpoint_dir, exist_ok=True)
    current_best = None
    lc_valid = []
    lc_train = []

    sampler = RandomSampler(search_space, seed=args.seed)
    training_strategies = {
        "standard": StandardStrategy(
            sampler=sampler, loss_function=nn.functional.cross_entropy
        ),
        "sandwich": SandwichStrategy(
            sampler=sampler, loss_function=nn.functional.cross_entropy
        ),
        "random": RandomStrategy(
            random_samples=2, sampler=sampler, loss_function=nn.functional.cross_entropy
        ),
    }
    update_op = training_strategies[args.training_strategy]

    for epoch in trange(args.epochs, desc=f"Training with {args.training_strategy}"):
        train_loss = 0
        model.train()

        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            # The training strategy computes the loss (and gradients) over the
            # sampled sub-networks; the optimizer then applies the update.
            loss = update_op(model, x, y)
            train_loss += loss
            optimizer.step()
        scheduler.step()

        valid_acc, valid_loss = validate(
            model=model,
            test_loader=test_loader,
            criterion=torch.nn.functional.cross_entropy,
            device=device,
        )

        lc_train.append(float(train_loss))
        lc_valid.append(float(valid_loss))
        if np.isnan(valid_loss):
            valid_loss = 1000000

        report(
            epoch=epoch + 1,
            train_loss=train_loss,
            valid_loss=valid_loss,
            valid_acc=valid_acc,
            num_params=sum(p.numel() for p in model.parameters() if p.requires_grad),
        )

        if current_best is None or current_best >= valid_loss:
            current_best = valid_loss
            if args.st_checkpoint_dir is not None:
                os.makedirs(args.st_checkpoint_dir, exist_ok=True)
                checkpoint = {
                    "state": model.state_dict(),
                    "config": {"fc1_out": args.fc1_out, "fc2_out": args.fc2_out},
                }
                torch.save(
                    checkpoint,
                    Path(args.st_checkpoint_dir)
                    / f"{args.training_strategy}_model_{args.fc1_out}_{args.fc2_out}.pt",
                )

    history = {"train_loss": lc_train, "valid_loss": lc_valid}
    json.dump(
        history,
        open(
            Path(args.st_checkpoint_dir)
            / f"{args.training_strategy}_history_{args.fc1_out}_{args.fc2_out}.json",
            "w",
        ),
    )

    # Evaluate a small grid of sub-networks of the trained super-network.
    lottery_grid = [[4, 4], [8, 8], [32, 16]]
    df = pd.DataFrame(
        {"fc1_out": [], "fc2_out": [], "accuracy": [], "loss": []},
    )
    for i, k in enumerate(lottery_grid):
        config = {"fc1_out": k[0], "fc2_out": k[1]}
        model.select_sub_network(config=config)
        acc, loss = validate(
            model=model,
            test_loader=test_loader,
            criterion=torch.nn.functional.cross_entropy,
            device=device,
        )
        df.loc[i] = [*list(config.values()), acc, loss]
    df = df.astype(
        {"fc1_out": "str", "fc2_out": "str", "accuracy": "float", "loss": "float"}
    )
    print(tabulate(df, headers="keys", tablefmt="pipe", floatfmt=".2f"))
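
For intuition about the strategies imported from whittle.training_strategies, here is a rough conceptual sketch of the sandwich rule (not whittle's actual implementation): each update accumulates gradients from the full super-network, the smallest sub-network, and a few randomly sampled intermediate sub-networks before the optimizer step. The helpers sample_config and smallest_config are hypothetical.

import torch.nn.functional as F


def sandwich_update(model, x, y, sample_config, smallest_config, num_random: int = 2):
    """Conceptual sandwich-rule step: accumulate gradients from the largest,
    smallest, and a few random sub-networks (hypothetical helpers)."""
    total_loss = 0.0

    # Largest network: the full super-network.
    model.reset_super_network()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    total_loss += loss.item()

    # Smallest sub-network in the search space.
    model.select_sub_network(config=smallest_config)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    total_loss += loss.item()

    # A few randomly sampled intermediate sub-networks.
    for _ in range(num_random):
        model.select_sub_network(config=sample_config())
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        total_loss += loss.item()

    model.reset_super_network()
    return total_loss

In the script above, update_op(model, x, y) plays this role between optimizer.zero_grad() and optimizer.step(), which then applies the accumulated gradients.
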
