[Feature]: LoRA support for InternVLChatModel #9495

Open · 1 task done · Tracked by #4194
AkshataABhat opened this issue Oct 18, 2024 · 15 comments

@AkshataABhat commented Oct 18, 2024

Your current environment

vllm version = 0.6.1

Model Input Dumps

No response

🐛 Describe the bug

vLLM version = 0.6.1. InternVLChatModel is in the list of supported models, but launching the server with LoRA enabled fails. The command and its output:

CUDA_VISIBLE_DEVICES=0 python3 -m vllm.entrypoints.openai.api_server --model OpenGVLab/InternVL2-8B --vllm_enable_lora=True --vllm_max_lora_rank=32 --lora-modules line_items=checkpoint-786/ --api-key=abcd  --host=0.0.0.0 --port=8817 --gpu_memory_utilization 0.95 --max_model_len=8192 --trust_remote_code --limit-mm-per-prompt 'image=16' 
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 636, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 840, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 272, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 270, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 46, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 39, in _init_executor
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/worker/worker.py", line 182, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 881, in load_model
[rank0]:     self.model = get_model(model_config=self.model_config,
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 341, in load_model
[rank0]:     model = _initialize_model(model_config, self.load_config,
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 170, in _initialize_model
[rank0]:     return build_model(
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 151, in build_model
[rank0]:     extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
[rank0]:   File "/root/anaconda3/envs/msswift_latest/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 128, in _get_model_initialization_kwargs
[rank0]:     raise ValueError(
[rank0]: ValueError: Model InternVLChatModel does not support LoRA, but LoRA is enabled. Support for this model may be added in the future. If this is important to you, please open an issue on github
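For reference, the upstream vLLM flags are spelled `--enable-lora` and `--max-lora-rank`; the `--vllm_*` spellings above appear to be wrapper-style argument names. The same check fires through the offline API too, since (as of 0.6.x) a model only accepts LoRA if its class implements vLLM's `SupportsLoRA` interface. A minimal sketch that reproduces the error, reusing the adapter name and checkpoint path from the command above:

```python
# Minimal reproduction sketch with vLLM's offline API (v0.6.x).
# "checkpoint-786/" is the LoRA checkpoint path from the command above.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="OpenGVLab/InternVL2-8B",
    trust_remote_code=True,
    enable_lora=True,     # triggers the SupportsLoRA check at load time
    max_lora_rank=32,
)
outputs = llm.generate(
    "Describe the invoice line items.",
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("line_items", 1, "checkpoint-786/"),
)
print(outputs[0].outputs[0].text)
```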

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@AkshataABhat added the bug label on Oct 18, 2024
@DarkLight1337 added the feature request label, removed the bug label, and changed the title to “[Feature]: LoRA support for InternVLChatModel” on Oct 18, 2024
@jeejeelee (Collaborator)

Could you provide your LoRA configuration? I might be able to implement this quickly.

@AkshataABhat (Author)

@jeejeelee Please help, it's quite urgent. Thanks! Here is the LoRA config:

"lora_rank": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_bias_trainable": "none",
"lora_dtype": null,
"lora_lr_ratio": null,
"use_rslora": false,
"use_dora": false,
"init_lora_weights": true,

@jeejeelee (Collaborator)

I will. Can you provide the target_modules info?

@AkshataABhat (Author) commented Oct 18, 2024

@jeejeelee
"target_modules": [
"fc2",
"w2",
"output",
"mlp1.3",
"fc1",
"mlp1.1",
"w3",
"proj",
"w1",
"wqkv",
"wo",
"qkv"
]

@jeejeelee (Collaborator)

Currently, vLLM does not support LoRA inference for the visual encoder and projector components; it only supports LoRA on the language model. Even if I implemented LoRA support for InternVL now, the inference results might not be correct, since your adapter also targets those components.
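If retraining is an option, one workaround is to target only the language-model projections so the adapter stays within what vLLM can serve. A rough sketch with PEFT, where the module names are assumptions based on the InternLM2-style naming in the config shared above:

```python
from peft import LoraConfig

# LoRA restricted to the InternLM2 language tower. The vision encoder
# modules (qkv/proj/fc1/fc2) and the mlp1 projector are deliberately excluded.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["wqkv", "wo", "w1", "w2", "w3"],  # attention + MLP only
)
```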

@AkshataABhat (Author) commented Oct 18, 2024

@jeejeelee What is the easiest workaround for this? Deploying with a merged LoRA is hurting output quality; I would rather deploy the original weights. Is there an alternative (fast inference) I can explore for a production deployment?

@jeejeelee (Collaborator)

> Deploying with a merged LoRA is hurting output quality

How many LoRAs do you have? If you only have one, merging the weights would give higher inference efficiency.
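If you do merge, a rough PEFT sketch (paths are placeholders; InternVL ships custom modeling code, hence trust_remote_code):

```python
import torch
from peft import PeftModel
from transformers import AutoModel

# Load the base model, fold the single adapter into its weights, and save
# a plain checkpoint that vLLM can serve without any LoRA flags.
base = AutoModel.from_pretrained(
    "OpenGVLab/InternVL2-8B", torch_dtype=torch.bfloat16, trust_remote_code=True
)
merged = PeftModel.from_pretrained(base, "checkpoint-786/").merge_and_unload()
merged.save_pretrained("InternVL2-8B-merged")
```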

@AkshataABhat (Author) commented Oct 18, 2024

I have 2 LoRAs, main and AdaLoRA. Sharing the whole config here for reference:

{
  "model_type": "internvl2-8b",
  "model_id_or_path": "OpenGVLab/InternVL2-8B",
  "model_revision": "main",
  "full_determinism": false,
  "sft_type": "lora",
  "freeze_parameters": [],
  "freeze_vit": false,
  "freeze_parameters_ratio": 0.0,
  "additional_trainable_parameters": [],
  "tuner_backend": "peft",
  "template_type": "internvl2",
  "output_dir": "LLaMA-Factory/output/internvl2-8b/v0-20240912-105045",
  "add_output_dir_suffix": false,
  "ddp_backend": null,
  "ddp_find_unused_parameters": null,
  "ddp_broadcast_buffers": null,
  "ddp_timeout": 1800,
  "seed": 42,
  "resume_from_checkpoint": null,
  "resume_only_model": false,
  "ignore_data_skip": false,
  "dtype": "bf16",
  "packing": false,
  "train_backend": "transformers",
  "tp": 1,
  "pp": 1,
  "min_lr": null,
  "sequence_parallel": false,
  "model_kwargs": null,
  "loss_name": null,
  "dataset": [
    "train.jsonl"
  ],
  "val_dataset": [],
  "dataset_seed": 42,
  "dataset_test_ratio": 0.01,
  "use_loss_scale": false,
  "loss_scale_config_path": "LLaMA-Factory/swift/swift/llm/agent/default_loss_scale_config.json",
  "system": "你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。",
  "tools_prompt": "react_en",
  "max_length": 3072,
  "truncation_strategy": "delete",
  "check_dataset_strategy": "none",
  "streaming": false,
  "streaming_val_size": 0,
  "streaming_buffer_size": 16384,
  "model_name": [
    null,
    null
  ],
  "model_author": [
    null,
    null
  ],
  "quant_method": null,
  "quantization_bit": 0,
  "hqq_axis": 0,
  "hqq_dynamic_config_path": null,
  "bnb_4bit_comp_dtype": "bf16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": true,
  "bnb_4bit_quant_storage": null,
  "rescale_image": -1,
  "target_modules": [
    "fc2",
    "w2",
    "output",
    "mlp1.3",
    "fc1",
    "mlp1.1",
    "w3",
    "proj",
    "w1",
    "wqkv",
    "wo",
    "qkv"
  ],
  "target_regex": null,
  "modules_to_save": [],
  "lora_rank": 32,
  "lora_alpha": 64,
  "lora_dropout": 0.05,
  "lora_bias_trainable": "none",
  "lora_dtype": null,
  "lora_lr_ratio": null,
  "use_rslora": false,
  "use_dora": false,
  "init_lora_weights": true,
  "fourier_n_frequency": 2000,
  "fourier_scaling": 300.0,
  "rope_scaling": null,
  "boft_block_size": 4,
  "boft_block_num": 0,
  "boft_n_butterfly_factor": 1,
  "boft_dropout": 0.0,
  "vera_rank": 256,
  "vera_projection_prng_key": 0,
  "vera_dropout": 0.0,
  "vera_d_initial": 0.1,
  "adapter_act": "gelu",
  "adapter_length": 128,
  "use_galore": false,
  "galore_target_modules": null,
  "galore_rank": 128,
  "galore_update_proj_gap": 50,
  "galore_scale": 1.0,
  "galore_proj_type": "std",
  "galore_optim_per_parameter": false,
  "galore_with_embedding": false,
  "galore_quantization": false,
  "galore_proj_quant": false,
  "galore_proj_bits": 4,
  "galore_proj_group_size": 256,
  "galore_cos_threshold": 0.4,
  "galore_gamma_proj": 2,
  "galore_queue_size": 5,
  "adalora_target_r": 8,
  "adalora_init_r": 12,
  "adalora_tinit": 0,
  "adalora_tfinal": 0,
  "adalora_deltaT": 1,
  "adalora_beta1": 0.85,
  "adalora_beta2": 0.85,
  "adalora_orth_reg_weight": 0.5,
  "ia3_feedforward_modules": [],
  "llamapro_num_new_blocks": 4,
  "llamapro_num_groups": null,
  "neftune_noise_alpha": null,
  "neftune_backend": "transformers",
  "lisa_activated_layers": 0,
  "lisa_step_interval": 20,
  "reft_layer_key": null,
  "reft_layers": null,
  "reft_rank": 4,
  "reft_intervention_type": "LoreftIntervention",
  "reft_args": null,
  "use_liger": false,
  "gradient_checkpointing": true,
  "deepspeed": null,
  "batch_size": 1,
  "eval_batch_size": 1,
  "auto_find_batch_size": false,
  "num_train_epochs": 3,
  "max_steps": -1,
  "optim": "adamw_torch",
  "adam_beta1": 0.9,
  "adam_beta2": 0.95,
  "adam_epsilon": 1e-08,
  "learning_rate": 5e-05,
  "weight_decay": 0.1,
  "gradient_accumulation_steps": 1,
  "max_grad_norm": 1,
  "predict_with_generate": false,
  "lr_scheduler_type": "cosine",
  "lr_scheduler_kwargs": {},
  "warmup_ratio": 0.05,
  "warmup_steps": 0,
  "eval_steps": 100,
  "save_steps": 100,
  "save_only_model": false,
  "save_total_limit": 2,
  "logging_steps": 5,
  "acc_steps": 1,
  "dataloader_num_workers": 1,
  "dataloader_pin_memory": true,
  "dataloader_drop_last": false,
  "push_to_hub": false,
  "hub_model_id": null,
  "hub_token": null,
  "hub_private_repo": false,
  "hub_strategy": "every_save",
  "test_oom_error": false,
  "disable_tqdm": false,
  "lazy_tokenize": true,
  "preprocess_num_proc": 1,
  "use_flash_attn": true,
  "ignore_args_error": true,
  "check_model_is_latest": true,
  "logging_dir": "LLaMA-Factory/output/internvl2-8b/v0-20240912-105045/runs",
  "report_to": [
    "tensorboard"
  ],
  "acc_strategy": "token",
  "save_on_each_node": false,
  "evaluation_strategy": "steps",
  "save_strategy": "steps",
  "save_safetensors": true,
  "gpu_memory_fraction": null,
  "include_num_input_tokens_seen": false,
  "local_repo_path": null,
  "custom_register_path": null,
  "custom_dataset_info": null,
  "device_map_config": null,
  "device_max_memory": [],
  "max_new_tokens": 2048,
  "do_sample": null,
  "temperature": null,
  "top_k": null,
  "top_p": null,
  "repetition_penalty": null,
  "num_beams": 1,
  "fsdp": "",
  "fsdp_config": null,
  "sequence_parallel_size": 1,
  "model_layer_cls_name": null,
  "metric_warmup_step": 0,
  "fsdp_num": 1,
  "per_device_train_batch_size": null,
  "per_device_eval_batch_size": null,
  "eval_strategy": null,
  "self_cognition_sample": 0,
  "train_dataset_mix_ratio": 0.0,
  "train_dataset_mix_ds": [
    "ms-bench"
  ],
  "train_dataset_sample": -1,
  "val_dataset_sample": null,
  "safe_serialization": null,
  "only_save_model": null,
  "neftune_alpha": null,
  "deepspeed_config_path": null,
  "model_cache_dir": null,
  "lora_dropout_p": null,
  "lora_target_modules": [
    "fc2",
    "w2",
    "output",
    "mlp1.3",
    "fc1",
    "mlp1.1",
    "w3",
    "proj",
    "w1",
    "wqkv",
    "wo",
    "qkv"
  ],
  "lora_target_regex": null,
  "lora_modules_to_save": [],
  "boft_target_modules": [],
  "boft_modules_to_save": [],
  "vera_target_modules": [],
  "vera_modules_to_save": [],
  "ia3_target_modules": [],
  "ia3_modules_to_save": [],
  "custom_train_dataset_path": [],
  "custom_val_dataset_path": [],
  "device_map_config_path": null,
  "push_hub_strategy": null,
  "use_self_cognition": false,
  "is_multimodal": true,
  "is_vision": true,
  "lora_use_embedding": false,
  "lora_use_all": true,
  "lora_m2s_use_embedding": false,
  "lora_m2s_use_ln": false,
  "torch_dtype": "torch.bfloat16",
  "fp16": false,
  "bf16": true,
  "rank": -1,
  "local_rank": -1,
  "world_size": 1,
  "local_world_size": 1,
  "bnb_4bit_compute_dtype": "torch.bfloat16",
  "load_in_4bit": false,
  "load_in_8bit": false,
  "train_sampler_random": true,
  "training_args": "Seq2SeqTrainingArguments(output_dir='LLaMA-Factory/output/internvl2-8b/v0-20240912-105045', overwrite_output_dir=False, do_train=False, do_eval=True, do_predict=False, eval_strategy=<IntervalStrategy.STEPS: 'steps'>, prediction_loss_only=False, per_device_train_batch_size=1, per_device_eval_batch_size=1, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=1, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=5e-05, weight_decay=0.1, adam_beta1=0.9, adam_beta2=0.95, adam_epsilon=1e-08, max_grad_norm=1, num_train_epochs=3, max_steps=-1, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, lr_scheduler_kwargs={}, warmup_ratio=0.05, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='LLaMA-Factory/output/internvl2-8b/v0-20240912-105045/runs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=True, logging_steps=5, logging_nan_inf_filter=True, save_strategy=<IntervalStrategy.STEPS: 'steps'>, save_steps=100, save_total_limit=2, save_safetensors=True, save_on_each_node=False, save_only_model=False, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=None, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=100, dataloader_num_workers=1, dataloader_prefetch_factor=None, past_index=-1, run_name='LLaMA-Factory/output/internvl2-8b/v0-20240912-105045', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model='loss', greater_is_better=False, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=False, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed=None, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=['tensorboard'], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=False, hub_always_push=False, gradient_checkpointing=True, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, eval_do_concat_batches=True, fp16_backend='auto', evaluation_strategy=None, push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, dispatch_batches=None, split_batches=None, include_tokens_per_second=False, include_num_input_tokens_seen=False, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, 
eval_on_start=False, eval_use_gather_object=False, sortish_sampler=True, predict_with_generate=False, generation_max_length=None, generation_num_beams=None, generation_config=GenerationConfig {\n  \"eos_token_id\": 92542,\n  \"max_new_tokens\": 2048,\n  \"pad_token_id\": 2\n}\n, train_sampler_random=True, acc_strategy='token', loss_name=None, additional_saved_files=[], metric_warmup_step=0, train_dataset_sample=-1)"
}

@AkshataABhat (Author)

@jeejeelee Efficiency is better, but the results are worse than with the unmerged one.

@jeejeelee (Collaborator)

@AkshataABhat I have preliminarily completed LoRA support for InternVL, but note that it still only applies LoRA to the language model. You can try it out; see my branch: https://github.com/jeejeelee/vllm/tree/internvl-lora

Additionally, I suggest you try adding LoRA only to the language model and retraining.

@AkshataABhat (Author)

@jeejeelee I discussed this with my team, and they want LoRA support for the vision model as well.

@zzf2grx commented Oct 28, 2024

> I have preliminarily completed LoRA support for InternVL, but note that it still only applies LoRA to the language model. You can try it out; see my branch: https://github.com/jeejeelee/vllm/tree/internvl-lora
>
> Additionally, I suggest you try adding LoRA only to the language model and retraining.

Hi~ I have some questions about this branch. Does it support InternVL2-8B (AWQ) with multiple LoRAs? Thank you!

@jeejeelee (Collaborator)

> Hi~ I have some questions about this branch. Does it support InternVL2-8B (AWQ) with multiple LoRAs? Thank you!

If InternVL supports AWQ, I think it should be fine. You can try it out, and if you have any issues, I'll help you solve them.
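Assuming you are running the branch above, a sketch of what multi-LoRA over an AWQ checkpoint would look like with the offline API (the AWQ repo name and adapter paths are placeholders):

```python
from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(
    model="OpenGVLab/InternVL2-8B-AWQ",  # assumed AWQ checkpoint
    quantization="awq",
    trust_remote_code=True,
    enable_lora=True,
    max_loras=2,          # adapters resident in a batch at once
    max_lora_rank=32,
)
# Each request can reference a different adapter; vLLM batches across them.
lora_a = LoRARequest("adapter_a", 1, "/path/to/lora_a")
lora_b = LoRARequest("adapter_b", 2, "/path/to/lora_b")
```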

@Zhiy-Zhang
Copy link

> If InternVL supports AWQ, I think it should be fine. You can try it out, and if you have any issues, I'll help you solve them.

@jeejeelee Hi~ I have a question about AWQ with LoRAs: when the language model is AWQ-quantized and served with multiple LoRAs, do the LoRA weights need to be quantized as well?

@jeejeelee (Collaborator)

> @jeejeelee Hi~ I have a question about AWQ with LoRAs: when the language model is AWQ-quantized and served with multiple LoRAs, do the LoRA weights need to be quantized as well?

You don't need to: the LoRA weights stay in half precision, and their contribution is added on top of the quantized layer's output.
