From 6784127eda7c1e2d820c4a7ec9199add350f9aef Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Mon, 27 Feb 2023 11:46:17 -0800 Subject: [PATCH] tag algos --- .../low_precision_groupnorm/low_precision_groupnorm.py | 7 +++++++ .../low_precision_layernorm/low_precision_layernorm.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py index 1ccf626d31..b1c127e535 100644 --- a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py +++ b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py @@ -55,6 +55,13 @@ def __init__(self, apply_at: Event = Event.INIT): if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}: raise ValueError('LowPrecisionGroupNorm only supports application on Event.INIT and Event.AFTER_LOAD.') + def __repr__(self) -> str: + return f'{self.__class__.__name__}(apply_at={self.apply_at})' + + @staticmethod + def required_on_load() -> bool: + return True + def match(self, event: Event, state: State) -> bool: del state # unused return event == self.apply_at diff --git a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py index a09f3972ef..b0c72f1e92 100644 --- a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py +++ b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py @@ -70,6 +70,13 @@ def __init__(self, apply_at: Event = Event.INIT): if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}: raise ValueError('LowPrecisionLayerNorm only supports application on Event.INIT and Event.AFTER_LOAD.') + def __repr__(self) -> str: + return f'{self.__class__.__name__}(apply_at={self.apply_at})' + + @staticmethod + def required_on_load() -> bool: + return True + def match(self, event: Event, state: State) -> bool: del state # unused return event == self.apply_at