This repository has been archived by the owner on Nov 15, 2021. It is now read-only.
forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tune] Update MNIST Example (ray-project#4991)
- Loading branch information
1 parent
bbe3e5b
commit b1827d5
Showing
3 changed files
with
112 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,189 +1,134 @@ | ||
# Original Code here: | ||
# https://github.com/pytorch/examples/blob/master/mnist/main.py | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import argparse | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torchvision import datasets, transforms | ||
|
||
# Training settings | ||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example") | ||
parser.add_argument( | ||
"--batch-size", | ||
type=int, | ||
default=64, | ||
metavar="N", | ||
help="input batch size for training (default: 64)") | ||
parser.add_argument( | ||
"--test-batch-size", | ||
type=int, | ||
default=1000, | ||
metavar="N", | ||
help="input batch size for testing (default: 1000)") | ||
parser.add_argument( | ||
"--epochs", | ||
type=int, | ||
default=1, | ||
metavar="N", | ||
help="number of epochs to train (default: 1)") | ||
parser.add_argument( | ||
"--lr", | ||
type=float, | ||
default=0.01, | ||
metavar="LR", | ||
help="learning rate (default: 0.01)") | ||
parser.add_argument( | ||
"--momentum", | ||
type=float, | ||
default=0.5, | ||
metavar="M", | ||
help="SGD momentum (default: 0.5)") | ||
parser.add_argument( | ||
"--no-cuda", | ||
action="store_true", | ||
default=False, | ||
help="disables CUDA training") | ||
parser.add_argument( | ||
"--seed", | ||
type=int, | ||
default=1, | ||
metavar="S", | ||
help="random seed (default: 1)") | ||
parser.add_argument( | ||
"--smoke-test", action="store_true", help="Finish quickly for testing") | ||
|
||
|
||
def train_mnist(args, config, reporter): | ||
vars(args).update(config) | ||
args.cuda = not args.no_cuda and torch.cuda.is_available() | ||
|
||
torch.manual_seed(args.seed) | ||
if args.cuda: | ||
torch.cuda.manual_seed(args.seed) | ||
|
||
kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {} | ||
import ray | ||
from ray import tune | ||
from ray.tune import track | ||
from ray.tune.schedulers import AsyncHyperBandScheduler | ||
|
||
# Change these values if you want the training to run quicker or slower. | ||
EPOCH_SIZE = 512 | ||
TEST_SIZE = 256 | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self, config): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 3, kernel_size=3) | ||
self.fc = nn.Linear(192, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(F.max_pool2d(self.conv1(x), 3)) | ||
x = x.view(-1, 192) | ||
x = self.fc(x) | ||
return F.log_softmax(x, dim=1) | ||
|
||
|
||
def train(model, optimizer, train_loader, device): | ||
model.train() | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
if batch_idx * len(data) > EPOCH_SIZE: | ||
return | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = F.nll_loss(output, target) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
|
||
def test(model, data_loader, device): | ||
model.eval() | ||
correct = 0 | ||
total = 0 | ||
with torch.no_grad(): | ||
for batch_idx, (data, target) in enumerate(data_loader): | ||
if batch_idx * len(data) > TEST_SIZE: | ||
break | ||
data, target = data.to(device), target.to(device) | ||
outputs = model(data) | ||
_, predicted = torch.max(outputs.data, 1) | ||
total += target.size(0) | ||
correct += (predicted == target).sum().item() | ||
|
||
return correct / total | ||
|
||
|
||
def get_data_loaders(): | ||
mnist_transforms = transforms.Compose( | ||
[transforms.ToTensor(), | ||
transforms.Normalize((0.1307, ), (0.3081, ))]) | ||
|
||
train_loader = torch.utils.data.DataLoader( | ||
datasets.MNIST( | ||
"~/data", | ||
train=True, | ||
download=False, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307, ), (0.3081, )) | ||
])), | ||
batch_size=args.batch_size, | ||
shuffle=True, | ||
**kwargs) | ||
"~/data", train=True, download=True, transform=mnist_transforms), | ||
batch_size=64, | ||
shuffle=True) | ||
test_loader = torch.utils.data.DataLoader( | ||
datasets.MNIST( | ||
"~/data", | ||
train=False, | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307, ), (0.3081, )) | ||
])), | ||
batch_size=args.test_batch_size, | ||
shuffle=True, | ||
**kwargs) | ||
|
||
class Net(nn.Module): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 10, kernel_size=5) | ||
self.conv2 = nn.Conv2d(10, 20, kernel_size=5) | ||
self.conv2_drop = nn.Dropout2d() | ||
self.fc1 = nn.Linear(320, 50) | ||
self.fc2 = nn.Linear(50, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(F.max_pool2d(self.conv1(x), 2)) | ||
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) | ||
x = x.view(-1, 320) | ||
x = F.relu(self.fc1(x)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=1) | ||
|
||
model = Net() | ||
if args.cuda: | ||
model.cuda() | ||
datasets.MNIST("~/data", train=False, transform=mnist_transforms), | ||
batch_size=64, | ||
shuffle=True) | ||
return train_loader, test_loader | ||
|
||
|
||
def train_mnist(config): | ||
use_cuda = config.get("use_gpu") and torch.cuda.is_available() | ||
device = torch.device("cuda" if use_cuda else "cpu") | ||
train_loader, test_loader = get_data_loaders() | ||
model = Net(config).to(device) | ||
|
||
optimizer = optim.SGD( | ||
model.parameters(), lr=args.lr, momentum=args.momentum) | ||
|
||
def train(epoch): | ||
model.train() | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
if args.cuda: | ||
data, target = data.cuda(), target.cuda() | ||
optimizer.zero_grad() | ||
output = model(data) | ||
loss = F.nll_loss(output, target) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
def test(): | ||
model.eval() | ||
test_loss = 0 | ||
correct = 0 | ||
with torch.no_grad(): | ||
for data, target in test_loader: | ||
if args.cuda: | ||
data, target = data.cuda(), target.cuda() | ||
output = model(data) | ||
# sum up batch loss | ||
test_loss += F.nll_loss(output, target, reduction="sum").item() | ||
# get the index of the max log-probability | ||
pred = output.argmax(dim=1, keepdim=True) | ||
correct += pred.eq( | ||
target.data.view_as(pred)).long().cpu().sum() | ||
|
||
test_loss = test_loss / len(test_loader.dataset) | ||
accuracy = correct.item() / len(test_loader.dataset) | ||
reporter(mean_loss=test_loss, mean_accuracy=accuracy) | ||
|
||
for epoch in range(1, args.epochs + 1): | ||
train(epoch) | ||
test() | ||
model.parameters(), lr=config["lr"], momentum=config["momentum"]) | ||
|
||
while True: | ||
train(model, optimizer, train_loader, device) | ||
acc = test(model, test_loader, device) | ||
track.log(mean_accuracy=acc) | ||
|
||
|
||
if __name__ == "__main__": | ||
datasets.MNIST("~/data", train=True, download=True) | ||
parser = argparse.ArgumentParser(description="PyTorch MNIST Example") | ||
parser.add_argument( | ||
"--cuda", | ||
action="store_true", | ||
default=False, | ||
help="Enables GPU training") | ||
parser.add_argument( | ||
"--smoke-test", action="store_true", help="Finish quickly for testing") | ||
parser.add_argument( | ||
"--ray-redis-address", | ||
help="Address of Ray cluster for seamless distributed execution.") | ||
args = parser.parse_args() | ||
|
||
import ray | ||
from ray import tune | ||
from ray.tune.schedulers import AsyncHyperBandScheduler | ||
|
||
ray.init() | ||
if args.ray_redis_address: | ||
ray.init(redis_address=args.ray_redis_address) | ||
sched = AsyncHyperBandScheduler( | ||
time_attr="training_iteration", | ||
metric="mean_loss", | ||
mode="min", | ||
max_t=400, | ||
grace_period=20) | ||
tune.register_trainable( | ||
"TRAIN_FN", | ||
lambda config, reporter: train_mnist(args, config, reporter)) | ||
time_attr="training_iteration", metric="mean_accuracy") | ||
tune.run( | ||
"TRAIN_FN", | ||
train_mnist, | ||
name="exp", | ||
scheduler=sched, | ||
**{ | ||
"stop": { | ||
"mean_accuracy": 0.98, | ||
"training_iteration": 1 if args.smoke_test else 20 | ||
}, | ||
"resources_per_trial": { | ||
"cpu": 3, | ||
"gpu": int(not args.no_cuda) | ||
}, | ||
"num_samples": 1 if args.smoke_test else 10, | ||
"config": { | ||
"lr": tune.uniform(0.001, 0.1), | ||
"momentum": tune.uniform(0.1, 0.9), | ||
} | ||
stop={ | ||
"mean_accuracy": 0.98, | ||
"training_iteration": 5 if args.smoke_test else 20 | ||
}, | ||
resources_per_trial={ | ||
"cpu": 2, | ||
"gpu": int(args.cuda) | ||
}, | ||
num_samples=1 if args.smoke_test else 10, | ||
config={ | ||
"lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())), | ||
"momentum": tune.uniform(0.1, 0.9), | ||
"use_gpu": int(args.cuda) | ||
}) |