-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow classification references to use the tensor backend #7629
Conversation
references/classification/train.py
Outdated
@@ -507,6 +511,7 @@ def get_args_parser(add_help=True): | |||
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" | |||
) | |||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") | |||
parser.add_argument("--backend", default="PIL", type=str, help="PIL or tensor - case insensitive") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC, there is a callback option here that we could use to call .lower()
. Even if not, we should probably do this during parsing or in the main script rather inside the transforms. Optimally, we would also check the values here to avoid doing this only at the transform level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer doing this at the lower level, those transforms can still be called on their own outside of the reference scripts. basically the "case-insensitive" feature exists at the transform level not at the args
level
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Should we still have this .lower()
call here to avoid doing it in the main script? I mean, right now there is only one occurrence so it doesn't really matter. However, if we have multiple ones in the future, we need to call it multiple times instead of one.
references/classification/train.py
Outdated
@@ -160,10 +161,13 @@ def load_data(traindir, valdir, args): | |||
else: | |||
if args.weights and args.test_only: | |||
weights = torchvision.models.get_weight(args.weights) | |||
preprocessing = weights.transforms() | |||
preprocessing = weights.transforms(antialias=True, backend=args.backend) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The antialias=True
is not related to this PR, right? Is True
the default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the default is None which is True for PIL and False for tensors. In those training references, we always want it to be True
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
references/classification/presets.py
Outdated
if backend == "tensor": | ||
trans.append(transforms.PILToTensor()) | ||
elif backend != "pil": | ||
raise ValueError("backend can be 'tensor' or 'pil', but got {backend}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise ValueError("backend can be 'tensor' or 'pil', but got {backend}") | |
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") |
references/classification/presets.py
Outdated
if backend == "tensor": | ||
trans.append(transforms.PILToTensor()) | ||
else: | ||
assert backend == "pil" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above.
references/classification/train.py
Outdated
@@ -507,6 +511,7 @@ def get_args_parser(add_help=True): | |||
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" | |||
) | |||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") | |||
parser.add_argument("--backend", default="PIL", type=str, help="PIL or tensor - case insensitive") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Should we still have this .lower()
call here to avoid doing it in the main script? I mean, right now there is only one occurrence so it doesn't really matter. However, if we have multiple ones in the future, we need to call it multiple times instead of one.
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
…7629) Summary: Co-authored-by: Philip Meier <github.pmeier@posteo.de> Reviewed By: vmoens Differential Revision: D46724120 fbshipit-source-id: 52df67c15514fd17e310c168846b94f80688062a
Results match between PIL and tensor with
--test-only
:I also verified that both
train_one_epoch()
andevaluate()
run without errors for both backends. I didn't check the acc on those (that would require full training), but considering this PR is mostly a simplified version of #7220, everything should be fine. If there's any issue we'll know soon enough.