Skip to content

Commit

Permalink
[cleanup][1/x] make hp_tensor_to_float8_dynamic only work with hp inputs
Browse files Browse the repository at this point in the history
Summary:

`hp_tensor_to_float8_dynamic` should only work with high precision
inputs; move the logic which checks for the input being already in float8
up to the callsites, to make it more explicit and easier to follow.

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 5aa5aaf6e776fe9bda230cebd0224404e0584372
ghstack-comment-id: 2560319845
Pull Request resolved: #1458
  • Loading branch information
vkuzo committed Dec 23, 2024
1 parent eab345c commit 135f7f6
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
17 changes: 10 additions & 7 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,16 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
self.config.cast_config_input.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
if tensor_already_casted_to_fp8(input):
input_fp8 = input
else:
assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = hp_tensor_to_float8_dynamic(
input,
self.config.cast_config_input.target_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
)
return input_fp8

def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
Expand Down
2 changes: 0 additions & 2 deletions torchao/float8/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def hp_tensor_to_float8_dynamic(
scaling_granularity: Defines the scaling granularity
axiswise_dim: if axiswise granularity is used, defines the dim to scale across
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(
hp_tensor,
float8_dtype,
Expand Down
4 changes: 3 additions & 1 deletion torchao/float8/stateful_float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def cast_input_to_float8(self, input: torch.Tensor) -> torch.Tensor:
autocast_dtype = torch.get_autocast_gpu_dtype()
input = input.to(autocast_dtype)

if self.scaling_type_input is ScalingType.DELAYED:
if tensor_already_casted_to_fp8(input):
input_fp8 = input
elif self.scaling_type_input is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
input,
Expand Down

0 comments on commit 135f7f6

Please sign in to comment.