Refine funcs #446

Merged 5 commits on Feb 19, 2025
20 changes: 11 additions & 9 deletions auto_round/calib_dataset.py
@@ -150,7 +150,7 @@ def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup"
 
 
 @register_dataset("BAAI/CCI3-HQ")
-def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False):
+def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False):
     """Returns a dataloader for the specified dataset and split.
 
     Args:
@@ -235,18 +235,19 @@ def get_new_chinese_title_dataset(
         seed=42,
         apply_chat_template=False
 ):
-    """Returns a dataloader for the specified dataset and split.
+    """
+    Returns a tokenized dataset for the specified parameters.
 
     Args:
-        tokenizer: The tokenizer to be used for tokenization.
-        seqlen: The maximum sequence length.
-        data_name: The name of the dataset.
-        split: The data split to be used (e.g., "train", "test").
-        seed: The random seed for shuffling the dataset.
-        apply_chat_template: Whether to apply chat template in tokenization.
+        tokenizer: The tokenizer to use.
+        seqlen: Maximum sequence length.
+        dataset_name: Name of the dataset to load.
+        split: Which split of the dataset to use.
+        seed: Random seed for shuffling.
+        apply_chat_template: Whether to apply a chat template to the data.
 
     Returns:
-        A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
+        A tokenized and shuffled dataset.
     """
 
     def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template):
@@ -639,3 +640,4 @@ def collate_batch(batch):
 
     calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
     return calib_dataloader
+
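Note that the rename from get_CCI3_HQ_dataset to get_cci3_hq_dataset leaves the registered name "BAAI/CCI3-HQ" unchanged: callers resolve loaders through the registry, not the Python symbol. A minimal sketch of that registry pattern, using illustrative names rather than auto_round's exact internals:

CALIB_DATASETS = {}  # illustrative registry dict; auto_round keeps its own mapping

def register_dataset(name):
    """Register a loader function under a dataset name."""
    def register(func):
        CALIB_DATASETS[name] = func
        return func
    return register

@register_dataset("BAAI/CCI3-HQ")
def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ",
                        split=None, seed=42, apply_chat_template=False):
    ...  # load, shuffle, and tokenize; elided in this sketch

# Callers look up the loader by its registered string name, so the
# Python-level rename is invisible to them:
loader_fn = CALIB_DATASETS["BAAI/CCI3-HQ"]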
5 changes: 3 additions & 2 deletions auto_round/export/export_to_autoround/export.py
@@ -50,7 +50,7 @@ def check_neq_config(config, data_type, bits, group_size, sym):
     return [key for key, expected_value in expected_config.items() if config.get(key) != expected_value]
 
 
-def dynamic_import_quantLinear_for_packing(backend, bits, group_size, sym):
+def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym):
     """
     Dynamically imports and returns the appropriate QuantLinear class based on the specified backend and parameters.
 
@@ -97,7 +97,7 @@ def pack_layer(name, model, layer_config, backend, pbar):
     layer = get_module(model, name)
     device = layer.weight.device
 
-    QuantLinear = dynamic_import_quantLinear_for_packing(backend, bits, group_size, sym)
+    QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym)
 
     if isinstance(layer, nn.Linear):
         in_features = layer.in_features
@@ -286,3 +286,4 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
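The snake_case rename is call-site compatible because pack_layer is updated in the same commit. As a hedged sketch of the dispatch style such a helper implies — the backend keys and kernel classes below are placeholders, not auto_round's real kernel map:

import torch.nn as nn

class TritonQuantLinear(nn.Module):
    """Placeholder for a Triton-backed quantized linear kernel."""

class ExllamaV2QuantLinear(nn.Module):
    """Placeholder for an exllamav2-backed kernel."""

_BACKEND_KERNELS = {"triton": TritonQuantLinear, "exllamav2": ExllamaV2QuantLinear}

def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym):
    # A real implementation would also validate bits/group_size/sym
    # against what each kernel supports before returning it.
    for key, cls in _BACKEND_KERNELS.items():
        if key in backend:
            return cls
    raise ValueError(f"unsupported backend: {backend}")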
7 changes: 4 additions & 3 deletions auto_round/low_cpu_mem/utils.py
@@ -78,7 +78,7 @@ def get_named_children(model, pre=[]):
     return module_list
 
 
-def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
+def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
     """Download hugging face model from hf hub."""
     from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
     from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
@@ -116,7 +116,7 @@ def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, sa
     if is_local:  # pragma: no cover
         path = pretrained_model_name_or_path
     else:
-        path = dowload_hf_model(pretrained_model_name_or_path)
+        path = download_hf_model(pretrained_model_name_or_path)
     torch_dtype = kwargs.pop("torch_dtype", None)
     if cls.__base__ == _BaseAutoModelClass:
         config = AutoConfig.from_pretrained(path, **kwargs)
@@ -258,7 +258,7 @@ def _get_path(pretrained_model_name_or_path):
     if is_local:  # pragma: no cover
         path = pretrained_model_name_or_path
     else:
-        path = dowload_hf_model(pretrained_model_name_or_path)
+        path = download_hf_model(pretrained_model_name_or_path)
     return path
 
 
@@ -471,3 +471,4 @@ def layer_wise_load(path):
     d = pickle.loads(d)
     state_dict.update(d)
     return state_dict
+
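The dowload_hf_model → download_hf_model rename fixes a long-standing typo, and both internal call sites are updated in this diff. For reference, roughly equivalent behavior is available through the public huggingface_hub API — an approximation only, since the function above resolves the cache path manually via repo_folder_name:

from huggingface_hub import snapshot_download

def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
    """Download a model snapshot from the HF Hub and return its local path."""
    return snapshot_download(
        repo_id,
        cache_dir=cache_dir,
        repo_type=repo_type,
        revision=revision,
    )

# e.g. path = download_hf_model("facebook/opt-125m")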
11 changes: 6 additions & 5 deletions auto_round/mllm/processor.py
@@ -20,15 +20,15 @@
 PROCESSORS = {}
 
 
-def regist_processor(name):
+def register_processor(name):
     def register(processor):
         PROCESSORS[name] = processor
         return processor
 
     return register
 
 
-@regist_processor("basic")
+@register_processor("basic")
 class BasicProcessor:
     def __init__(self):
         pass
@@ -111,7 +111,7 @@ def squeeze_result(ret):
         return ret
 
 
-@regist_processor("qwen2_vl")
+@register_processor("qwen2_vl")
 class Qwen2VLProcessor(BasicProcessor):
     @staticmethod
     def squeeze_result(ret):
@@ -122,7 +122,7 @@ def squeeze_result(ret):
         return ret
 
 
-@regist_processor("cogvlm2")
+@register_processor("cogvlm2")
 class CogVLM2Processor(BasicProcessor):
     def get_input(
             self, text, images, truncation=False,
@@ -205,7 +205,7 @@ def default_image_processor(image_path_or_url):
 llava_train = LazyImport("llava.train.train")
 
 
-@regist_processor("llava")
+@register_processor("llava")
 class LlavaProcessor(BasicProcessor):
     def post_init(self, model, tokenizer, image_processor=None, **kwargs):
         self.model = model
@@ -245,3 +245,4 @@ class DataArgs:
 
     def data_collator(self, batch):
         return self.collator_func(batch)
+
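As with the dataset registry, regist_processor → register_processor changes only the Python symbol; the string keys ("basic", "qwen2_vl", "cogvlm2", "llava") that callers use to look up a processor are untouched. A short self-contained usage sketch — the fallback to "basic" is an illustrative choice, not necessarily auto_round's behavior:

PROCESSORS = {}  # mirrors the registry shown in the diff above

def register_processor(name):
    def register(processor):
        PROCESSORS[name] = processor
        return processor
    return register

@register_processor("basic")
class BasicProcessor:
    pass

# Model-specific code looks processors up by their registered string name,
# so renaming the decorator itself is backward compatible at call sites:
processor_cls = PROCESSORS.get("qwen2_vl", PROCESSORS["basic"])
processor = processor_cls()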