[AIR] ResourceChangingScheduler causes tuning to hang with a large trial number #30265
Comments
@Tonyhao96 I tried to quickly reproduce that with the AIR example you linked in the first issue and it worked fine for me. Is it possible for you to share a whole script that reproduces this for you, along with your cluster setup?
import os
import time
import argparse
from pathlib import Path
from filelock import FileLock

# import models.cifar as Cifar  # local module, not used in this script

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.models import resnet18

import ray
from ray import tune
from ray.train.torch import TorchTrainer, TorchCheckpoint
import ray.train.torch as ht
from ray.air import session
from ray.air.config import FailureConfig, RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
from ray.tune.schedulers.resource_changing_scheduler import (
    ResourceChangingScheduler,
    DistributeResources,
    DistributeResourcesToTopJob,
)
from ray.tune.tuner import Tuner
from ray.tune.tune_config import TuneConfig

SEARCH_SPACE = {
    "lr": tune.qloguniform(1e-4, 1, 1e-4),
    "momentum": tune.quniform(0.5, 0.999, 0.001),
    "batch_size": tune.choice([128, 256, 512]),
    "gamma": tune.quniform(0.01, 0.9, 0.01),
}


def get_datasets(dataset):
    """Data loaders for CIFAR-10/100 & ImageNet."""
    if dataset == "cifar10":
        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
        )
        with FileLock(Path("~/data/data.lock").expanduser()):
            train_dataset = datasets.CIFAR10(
                root="~/data",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ]
                ),
            )
            val_dataset = datasets.CIFAR10(
                root="~/data",
                train=False,
                download=False,
                transform=transforms.Compose([transforms.ToTensor(), normalize]),
            )
        return train_dataset, val_dataset


def train_epoch(dataloader, model, loss_fn, optimizer, fusion_num):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation
        optimizer.zero_grad()
        # loss.backward()
        ht.backward(loss)  # For AMP support
        optimizer.step()


def validate_epoch(dataloader, model, loss_fn, fusion_num):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return {"loss": test_loss, "val_acc": correct}


def train_func(config):
    ht.accelerate(amp=config["amp"])  # For AMP support
    ht.enable_reproducibility(seed=config["seed"])
    fusion_num = config.get("FUSION_N", -1)
    dataset = config.get("dataset")
    if dataset == "imagenet":
        model = torchvision.models.__dict__[config.get("model")]()
    elif dataset in ("cifar10", "cifar100"):
        model = resnet18()
    model = ht.prepare_model(model)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.01),
        momentum=config.get("momentum", 0.9),
        weight_decay=config.get("weight_decay", 0.001),
    )
    optimizer = ht.prepare_optimizer(optimizer)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=20, gamma=config.get("gamma", 0.2)
    )
    worker_batch_size = config["batch_size"] // session.get_world_size()
    train_set, val_set = get_datasets(dataset)
    train_loader = DataLoader(
        train_set, batch_size=worker_batch_size, num_workers=8, pin_memory=True, shuffle=True
    )
    val_loader = DataLoader(val_set, batch_size=worker_batch_size, num_workers=8, pin_memory=True)
    train_loader = ht.prepare_data_loader(train_loader)
    val_loader = ht.prepare_data_loader(val_loader)
    # Create loss.
    criterion = nn.CrossEntropyLoss()
    for _ in range(10000):
        train_epoch(train_loader, model, criterion, optimizer, fusion_num)
        result = validate_epoch(val_loader, model, criterion, fusion_num)
        lr_scheduler.step()
        session.report(result, checkpoint=TorchCheckpoint.from_model(model))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", default="auto", type=str, help="the address to use for Redis")
    parser.add_argument("--model", type=str, default="resnet18", help="Model to use")
    parser.add_argument("--dataset", type=str, default="cifar10", help="Dataset to use")
    parser.add_argument("--scheduler", default="fifo", choices=["asha", "fifo"], type=str, help="Scheduler Algorithm")
    parser.add_argument("--max-epoch", default=100, type=int, help="Max Epochs")
    parser.add_argument("--max-sample", default=200, type=int, help="Max Samples")
    parser.add_argument("--max-time", default=-1, type=int, help="Max Time (s), -1 for no limit")
    parser.add_argument("--target-acc", default=0.96, type=float, help="Target Validation Accuracy")
    parser.add_argument("--amp", default=False, type=bool, help="Whether to enable AMP")
    parser.add_argument("--mps", default=1, type=int, help="Whether to enable MPS for GPU sharing")
    parser.add_argument("--seed", default=1, type=int, help="Fix random seed for reproducibility")
    parser.add_argument("--addition-str", default="", type=str, help="Additional string for experiment name")
    ##### ASHA Parameters
    parser.add_argument("--grace", default=3, type=int, help="grace_period")
    parser.add_argument("--reduction", default=3, type=int, help="reduction_factor")
    parser.add_argument("--brackets", default=1, type=int, help="brackets")
    args, _ = parser.parse_known_args()

    # ray.init(address=args.address)
    ray.init(address=None)

    config = SEARCH_SPACE | {
        "model": args.model,
        "dataset": args.dataset,
        "seed": args.seed,
        "amp": args.amp,
    }
    trainer = TorchTrainer(
        train_func,
        train_loop_config=config,
        scaling_config=ScalingConfig(
            num_workers=1,
            use_gpu=True,
            resources_per_worker={"CPU": 8 / args.mps, "GPU": 1 / args.mps},
            _max_cpu_fraction_per_node=0.9,
        ),
    )
    tune_scheduler = FIFOScheduler()
    tune_scheduler = ResourceChangingScheduler(
        base_scheduler=tune_scheduler,
        resources_allocation_function=DistributeResources(add_bundles=True),  # default
    )
    experiment_name = f"{args.model}_{args.dataset}_s{args.max_sample}_e{args.max_epoch}"
    tuner = Tuner(
        trainer,
        param_space={"train_loop_config": config},
        tune_config=TuneConfig(
            num_samples=args.max_sample,
            metric="val_acc",
            mode="max",
            scheduler=tune_scheduler,
            time_budget_s=args.max_time if args.max_time > 0 else None,
        ),
        run_config=RunConfig(
            name=experiment_name,
            local_dir="../ray_results",
            log_to_file=True,
            stop={"training_iteration": args.max_epoch, "val_acc": args.target_acc},
            checkpoint_config=CheckpointConfig(num_to_keep=1),
            # callbacks=[WandbLoggerCallback(api_key_file="~/.wandb/api_key", project=f"{experiment_name}")],
            failure_config=FailureConfig(fail_fast=True, max_failures=0),
        ),
    )
    results = tuner.fit()

    print(results.get_best_result(metric="val_acc", mode="max"))
    df = results.get_dataframe()
    df.to_csv(f"../ray_results/{experiment_name}.csv")
    time.sleep(5)
    os.system("ray stop --force")
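A stripped-down variant of the script above may be easier to iterate on when bisecting the hang. The sketch below keeps only the TorchTrainer / ResourceChangingScheduler / Tuner wiring from the posted script; the dummy training loop, the CPU-only scaling config, and the experiment name are placeholder assumptions and not part of the original report.

```python
import time

import ray
from ray.air import session
from ray.air.config import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import FIFOScheduler
from ray.tune.schedulers.resource_changing_scheduler import (
    ResourceChangingScheduler,
    DistributeResources,
)
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner


def train_func(config):
    # Report a fake metric every "epoch" instead of training a real model.
    for i in range(10):
        time.sleep(1)
        session.report({"val_acc": i / 10})


if __name__ == "__main__":
    ray.init()
    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=1, use_gpu=False),
    )
    scheduler = ResourceChangingScheduler(
        base_scheduler=FIFOScheduler(),
        resources_allocation_function=DistributeResources(add_bundles=True),
    )
    tuner = Tuner(
        trainer,
        tune_config=TuneConfig(
            num_samples=200,  # large trial number, as in the report
            metric="val_acc",
            mode="max",
            scheduler=scheduler,
        ),
        run_config=RunConfig(name="rcs_hang_repro"),
    )
    tuner.fit()
```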
Hey, thanks, I can reproduce the behavior using the script you provided. I believe I have identified the issue. Can you check if removing the
Thank you very much. Removing the
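The two comments above do not spell out which setting was removed. Assuming it is the experimental `_max_cpu_fraction_per_node` argument in the posted ScalingConfig (only an assumption, not confirmed by the thread), the change would amount to something like this sketch:

```python
# Sketch only: assumes the removed setting is _max_cpu_fraction_per_node.
trainer = TorchTrainer(
    train_func,
    train_loop_config=config,
    scaling_config=ScalingConfig(
        num_workers=1,
        use_gpu=True,
        resources_per_worker={"CPU": 8 / args.mps, "GPU": 1 / args.mps},
        # _max_cpu_fraction_per_node=0.9,  # dropped for this check
    ),
)
```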
I just took a quick check, by changing

Number of trials: 6/6 (1 ERROR, 5 TERMINATED)
+--------------------------+------------+----------------------+------------------------+------------------------+------------------------+------------------------+--------+------------------+----------+-----------+--------------+
| Trial name               | status     | loc                  | train_loop_config/ba   | train_loop_config/ga   | train_loop_config/lr   | train_loop_config/mo   |   iter |   total time (s) |     loss |   val_acc |   _timestamp |
|                          |            |                      | tch_size               | mma                    |                        | mentum                 |        |                  |          |           |              |
|--------------------------+------------+----------------------+------------------------+------------------------+------------------------+------------------------+--------+------------------+----------+-----------+--------------|
| TorchTrainer_38651_00000 | TERMINATED | 10.100.77.179:343446 | 256                    | 0.28                   | 0.0047                 | 0.859                  |     10 |         100.938  | 0.594755 |    0.8013 |   1668568759 |
| TorchTrainer_38651_00001 | TERMINATED | 10.100.77.179:343499 | 128                    | 0.36                   | 0.0004                 | 0.546                  |     10 |         172.939  | 1.0751   |    0.6165 |   1668568834 |
| TorchTrainer_38651_00002 | TERMINATED | 10.100.77.179:343501 | 128                    | 0.38                   | 0.0478                 | 0.967                  |     10 |         166.395  | 0.900563 |    0.6938 |   1668568827 |
| TorchTrainer_38651_00003 | TERMINATED | 10.100.77.179:343503 | 512                    | 0.21                   | 0.0551                 | 0.602                  |     10 |          67.7346 | 0.641064 |    0.7892 |   1668568729 |
| TorchTrainer_38651_00004 | TERMINATED | 10.100.77.179:351670 | 256                    | 0.39                   | 0.0137                 | 0.956                  |     10 |          95.2324 | 0.612997 |    0.7935 |   1668568828 |
| TorchTrainer_38651_00005 | ERROR      | 10.100.77.179:354571 | 256                    | 0.87                   | 0.5708                 | 0.888                  |      7 |          70.6566 | 2.14731  |    0.4166 |   1668568833 |
+--------------------------+------------+----------------------+------------------------+------------------------+------------------------+------------------------+--------+------------------+----------+-----------+--------------+
An error is raised after a while.
Originally posted by @Tonyhao96 in #30247 (comment)