Allow classification references to use the tensor backend #7629

Merged
merged 8 commits into from
May 31, 2023

Conversation

NicolasHug
Member

Results match between PIL and tensor with --test-only:

(pt) ➜  classification git:(main) ✗ torchrun --nproc_per_node=4 train.py --model resnet50 --data-path /datasets01_ontap/imagenet_full_size/061417 --test-only --weights ResNet50_Weights.IMAGENET1K_V1 --backend pil

Test:  Acc@1 76.130 Acc@5 92.856


(pt) ➜  classification git:(main) ✗ torchrun --nproc_per_node=4 train.py --model resnet50 --data-path /datasets01_ontap/imagenet_full_size/061417 --test-only --weights ResNet50_Weights.IMAGENET1K_V1 --backend tensor

Test:  Acc@1 76.138 Acc@5 92.866

I also verified that both train_one_epoch() and evaluate() run without errors for both backends. I didn't check the accuracy on those (that would require full training), but since this PR is mostly a simplified version of #7220, everything should be fine. If there's an issue, we'll know soon enough.

@pytorch-bot

pytorch-bot bot commented May 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7629

Note: Links to docs will display an error until the docs builds have been completed.

❌ 32 New Failures

As of commit e18c1d1:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -507,6 +511,7 @@ def get_args_parser(add_help=True):
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+ parser.add_argument("--backend", default="PIL", type=str, help="PIL or tensor - case insensitive")
Collaborator

IIRC, there is a callback option here that we could use to call .lower(). Even if not, we should probably do this during parsing or in the main script rather than inside the transforms. Ideally, we would also validate the value here instead of only at the transform level.
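
For illustration, here is a minimal sketch of what parse-time normalization could look like with the standard library's argparse. The `get_backend_parser` helper is hypothetical and models only the `--backend` option; it is not the code in this PR. It assumes `type=str.lower` as the conversion callable and `choices` for validation, which argparse checks after the type conversion.

```python
import argparse


def get_backend_parser():
    # Hypothetical, stripped-down parser: only --backend from the diff is modeled.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--backend",
        default="pil",
        type=str.lower,  # argparse calls this on the raw command-line string
        choices=["pil", "tensor"],  # validated after the type conversion
        help="PIL or tensor - case insensitive",
    )
    return parser


args = get_backend_parser().parse_args(["--backend", "Tensor"])
print(args.backend)
```

With this approach an unsupported value is rejected by the parser itself, before any transform is built.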

Member Author

@NicolasHug NicolasHug May 25, 2023

I prefer doing this at the lower level: those transforms can still be called on their own, outside of the reference scripts. Basically, the "case-insensitive" feature exists at the transform level, not at the args level.
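
A rough sketch of the transform-level alternative, so the trade-off is concrete. `EvalPresetSketch` is a hypothetical stand-in that records step names as strings instead of building real transforms; only the normalization and validation logic is the point.

```python
class EvalPresetSketch:
    """Hypothetical stand-in for a preset class; stores step names, not real transforms."""

    def __init__(self, backend="pil"):
        # Normalizing here keeps the class case-insensitive even when it is
        # used directly, without going through the reference script's parser.
        backend = backend.lower()
        if backend not in ("pil", "tensor"):
            raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
        self.backend = backend
        # With the tensor backend, the PIL image is converted up front.
        self.steps = ["PILToTensor"] if backend == "tensor" else []


preset = EvalPresetSketch(backend="Tensor")
print(preset.backend, preset.steps)
```

Because the check lives in the constructor, callers outside the reference scripts get the same behavior and the same error message.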

Collaborator

Nit: Should we still have this .lower() call here to avoid doing it in the main script? I mean, right now there is only one occurrence so it doesn't really matter. However, if we have multiple ones in the future, we would need to call it multiple times instead of once.

@@ -160,10 +161,13 @@ def load_data(traindir, valdir, args):
else:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
- preprocessing = weights.transforms()
+ preprocessing = weights.transforms(antialias=True, backend=args.backend)
Collaborator

The antialias=True is not related to this PR, right? Is True the default?

Member Author

The default is None, which means True for PIL and False for tensors. In these training references, we always want it to be True.
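
To make the default described above explicit, here is a small sketch of how a `None` default could resolve per backend. `resolve_antialias` is a hypothetical helper written only to mirror the comment's description; it is not torchvision's implementation.

```python
def resolve_antialias(antialias, backend):
    """Hypothetical helper mirroring the behavior described in the comment above."""
    if antialias is not None:
        # An explicit value always wins, which is why the references pass True.
        return antialias
    # None falls back to a backend-dependent default: True for PIL, False for tensors.
    return backend.lower() == "pil"


print(resolve_antialias(None, "pil"))
print(resolve_antialias(None, "tensor"))
print(resolve_antialias(True, "tensor"))
```

Passing `antialias=True` explicitly removes the backend-dependent difference, so both backends resize the same way.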

NicolasHug and others added 3 commits May 25, 2023 13:02
Collaborator

@pmeier pmeier left a comment

Thanks!

if backend == "tensor":
trans.append(transforms.PILToTensor())
elif backend != "pil":
raise ValueError("backend can be 'tensor' or 'pil', but got {backend}")
Collaborator

Suggested change
- raise ValueError("backend can be 'tensor' or 'pil', but got {backend}")
+ raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
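
A short demonstration of why the `f` prefix matters here: without it, Python treats `{backend}` as literal text, so the error would never show the offending value. The value `"numpy"` is just an example of an unsupported backend.

```python
backend = "numpy"  # example of an unsupported value

# Plain string: the braces are kept verbatim, the variable is not interpolated.
plain = "backend can be 'tensor' or 'pil', but got {backend}"

# f-string: {backend} is replaced by the variable's value.
formatted = f"backend can be 'tensor' or 'pil', but got {backend}"

print(plain)
print(formatted)
```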

if backend == "tensor":
trans.append(transforms.PILToTensor())
else:
assert backend == "pil"
Collaborator

Same as above.

@NicolasHug NicolasHug merged commit 0ab7d05 into pytorch:main May 31, 2023
@github-actions

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Jun 14, 2023
…7629)

Summary: Co-authored-by: Philip Meier <github.pmeier@posteo.de>

Reviewed By: vmoens

Differential Revision: D46724120

fbshipit-source-id: 52df67c15514fd17e310c168846b94f80688062a