diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
index 1657f895dcd7..fb4d3fa208a8 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -81,8 +81,6 @@
     "divide",
     "nn.bias_add",
     "nn.batch_norm",
-    "sum",
-    "mean",
     "sqrt",
     "shape_of",
     # Simple activations
@@ -107,15 +105,9 @@
     # "nn.global_max_pool1d", # does not exist yet
     "nn.global_max_pool2d",
     # "nn.global_max_pool3d", # does not exist yet
-    # "nn.global_avg_pool1d", # does not exist yet
-    "nn.global_avg_pool2d",
-    # "nn.global_avg_pool3d", # does not exist yet
     "nn.adaptive_max_pool1d",
     "nn.adaptive_max_pool2d",
     "nn.adaptive_max_pool3d",
-    "nn.adaptive_avg_pool1d",
-    "nn.adaptive_avg_pool2d",
-    "nn.adaptive_avg_pool3d",
 ]
 DEFAULT_NEVER_LIST = [
     # In general if |f(x)| >> |x| for expected inputs then put the op here.
@@ -131,6 +123,13 @@
     # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number
     # not representable in fp16.
     "arange",
+    # Ops that could involve a large summation are not allowed in fp16.
+    "nn.global_avg_pool2d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+    "sum",
+    "mean",
 ]
 
 
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 99078b7371ba..472f98715ec5 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -221,12 +221,9 @@ def test_do_not_convert_softmax():
     b = relay.nn.softmax(a)
     mod = tvm.IRModule.from_expr(b)
     mod = tvm.relay.transform.InferType()(mod)
-
-    mod_params = {
-        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
-    }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_do_not_convert_arange():
@@ -234,10 +231,26 @@ def test_do_not_convert_arange():
     dtype = "float32"
     arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
     mod = tvm.IRModule.from_expr(arange)
-    mod = tvm.relay.transform.InferType()(mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
-    output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+
+def test_do_not_convert_summation():
+    """Ops that could involve a large summation are not allowed in fp16."""
+    shape = [1, 3, 16, 16]
+    a = relay.var("a", shape=shape)
+    ops = [
+        relay.sum,
+        relay.mean,
+        relay.nn.global_avg_pool2d,
+        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
+    ]
+    for op in ops:
+        mod = tvm.IRModule.from_expr(op(a))
+        out_mod = ToMixedPrecision("float16")(mod)
+        orig_mod = tvm.relay.transform.InferType()(mod)
+        assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_green_gray_propagates_simple():
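
For context on the motivation (this note and the snippet below are not part of the patch): float16 tops out at 65504, so reductions that sum many elements, such as `sum`, `mean`, and the global/adaptive average pools moved here into `DEFAULT_NEVER_LIST`, can overflow or lose precision even when their inputs are well behaved. This is exactly the `|f(x)| >> |x|` condition the never-list comment describes. A minimal NumPy sketch, assuming only that NumPy is available:

```python
# Illustration only (not part of this patch): why summation-like ops stay in fp32.
import numpy as np

print(np.finfo(np.float16).max)  # 65504.0, the largest finite float16 value

# A modest NCHW feature map of ones: 1 * 3 * 512 * 512 = 786432 elements.
x = np.ones((1, 3, 512, 512), dtype="float16")

print(x.sum(dtype="float16"))    # inf       -- the float16 running sum overflows 65504
print(x.sum(dtype="float32"))    # 786432.0  -- a float32 accumulator is fine

# Even the mean overflows when accumulated in float16, although the true mean (1.0)
# is perfectly representable; this is what keeps "mean" and the avg pools on the list.
print(x.mean(dtype="float16"))   # inf
print(x.mean(dtype="float32"))   # 1.0
```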