Commit

fix typing
SunMarc committed Dec 11, 2023
1 parent 289408e commit 2a33aae
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/gptq/quantizer.py
@@ -77,7 +77,7 @@ def __init__(
exllama_config: Dict[str, Any] = None,
max_input_length: Optional[int] = None,
cache_block_outputs: Optional[bool] = True,
-inside_layer_modules: Optional[list] = None,
+inside_layer_modules: Optional[List[List[str]]] = None,
*args,
**kwargs,
):
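For reference, the corrected hint relies on `List` and `Optional` from the standard `typing` module. A minimal standalone sketch of the annotation is below; the import line is an assumption about what the file already pulls in, not part of this diff:

```python
from typing import List, Optional

# The parameter now advertises nested structure: each inner list is one group
# of module-name suffixes to quantize inside a block (shape inferred from the
# docstring change in the next hunk).
inside_layer_modules: Optional[List[List[str]]] = None
```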
@@ -124,7 +124,7 @@ def __init__(
cache_block_outputs (`bool`, defaults to `True`):
Whether to cache block outputs to reuse as inputs for the succeeding block. It allows optimization of non-standard models
(e.g. ChatGLM) but can require more time.
-inside_layer_modules (`List`, *optional*, defaults to `None`):
+inside_layer_modules (`List[List[str]]`, *optional*, defaults to `None`):
List of module names to quantize inside block_name_to_quantize. If not set, we will quantize all the linear layers.
"""

@@ -215,7 +215,7 @@ def convert_model(self, model: nn.Module):
block_name = self.block_name_to_quantize
layers_to_be_replaced = get_layers(model, prefix=block_name)
if self.inside_layer_modules is not None:
-layers_to_keep = sum(self.inside_layer_modules,[])
+layers_to_keep = sum(self.inside_layer_modules, [])
for name in list(layers_to_be_replaced.keys()):
if not any(name.endswith(layer) for layer in layers_to_keep):
logger.info(f"{name} has not been quantized. We don't convert it")
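To make the flattening step concrete, here is a small self-contained sketch of `sum(nested, [])` followed by the `endswith` filter. The dictionary and module names are stand-ins rather than real output of `get_layers`, and the deletion at the end is an assumption about what follows the truncated log call:

```python
# Stand-in data; in the real code this dict comes from get_layers(model, prefix=block_name).
layers_to_be_replaced = {
    "model.layers.0.self_attn.q_proj": "linear-0",
    "model.layers.0.mlp.down_proj": "linear-1",
    "model.layers.0.mlp.gate_proj": "linear-2",
}
inside_layer_modules = [["self_attn.q_proj"], ["mlp.down_proj"]]

# sum(list_of_lists, []) concatenates the sub-lists into one flat list of suffixes.
layers_to_keep = sum(inside_layer_modules, [])  # ["self_attn.q_proj", "mlp.down_proj"]

for name in list(layers_to_be_replaced.keys()):
    if not any(name.endswith(layer) for layer in layers_to_keep):
        # Unmatched modules are left unquantized (assumed follow-up to the log message).
        del layers_to_be_replaced[name]

print(sorted(layers_to_be_replaced))
# ['model.layers.0.mlp.down_proj', 'model.layers.0.self_attn.q_proj']
```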
@@ -453,7 +453,7 @@ def store_input_hook(_, input, *args):
if not has_device_map or get_device(block) == torch.device("cpu"):
block = block.to(0)
layers = get_layers(block)
-if isinstance(self.inside_layer_modules,list) and len(self.inside_layer_modules)>0:
+if isinstance(self.inside_layer_modules, list) and len(self.inside_layer_modules) > 0:
if self.true_sequential:
layers_name_list = [sum(self.inside_layer_modules, [])]
else:
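For the grouping shown in the last hunk, a minimal sketch of the two ways a nested module-name list can be consumed; the hunk is truncated before the `else` branch, so which policy maps to `true_sequential` is illustrative only:

```python
inside_layer_modules = [
    ["self_attn.q_proj", "self_attn.k_proj"],
    ["mlp.down_proj"],
]

# A single group containing every suffix (the flattened form from the hunk above).
flattened = [sum(inside_layer_modules, [])]
# One group per inner list, preserving their order.
grouped = list(inside_layer_modules)

print(flattened)  # [['self_attn.q_proj', 'self_attn.k_proj', 'mlp.down_proj']]
print(grouped)    # [['self_attn.q_proj', 'self_attn.k_proj'], ['mlp.down_proj']]
```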
