
[Model] Add Qwen2-Audio model support #9248

Merged (32 commits) on Oct 23, 2024

Conversation

@faychu (Contributor) commented Oct 10, 2024

This PR adds support for the Qwen2-Audio model.

FIX #8394
FIX #8461

Requirements

Use transformers>=4.45.1, and please install vLLM from source.

Example Usage

import requests


from transformers import AutoTokenizer, AutoProcessor
from transformers.pipelines.audio_utils import ffmpeg_read

from vllm import LLM, SamplingParams

MODEL_PATH = 'Qwen/Qwen2-Audio-7B-Instruct'


def qwen2_audio_batch():
    processor = AutoProcessor.from_pretrained(MODEL_PATH)

    conversation1 = [
        {"role": "user", "content": [
            {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"},
            {"type": "text", "text": "What's that sound?"},
        ]},
        {"role": "assistant", "content": "It is the sound of glass shattering."},
        {"role": "user", "content": [
            {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav"},
            {"type": "text", "text": "What can you hear?"},
        ]},
    ]

    conversation2 = [
        {"role": "user", "content": [
            {"type": "audio", "audio_url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/1272-128104-0000.flac"},
            {"type": "text", "text": "What does the person say?"},
        ]},
    ]

    conversation3 = [
        {"role": "user", "content": [
            {"type": "text", "text": "How to make a pizza?"},
        ]},
    ]

    conversations = [conversation1, conversation2, conversation3]

    text = [
        processor.apply_chat_template(conversation,
                                      add_generation_prompt=True,
                                      tokenize=False,
                                      add_audio_id=True)
        for conversation in conversations
    ]

    sampling_rate = processor.feature_extractor.sampling_rate
    audios = []
    for conversation in conversations:
        audio_infos_vllm = []
        for message in conversation:
            if isinstance(message["content"], list):
                for ele in message["content"]:
                    if ele["type"] == "audio":
                        audio_infos_vllm.append(
                            (ffmpeg_read(requests.get(ele['audio_url']).content,
                                         sampling_rate=sampling_rate),
                             sampling_rate))
        audios.append(audio_infos_vllm)

    inputs = [
        {
            'prompt': text[i],
            'multi_modal_data': {
                'audio': audios[i]
            }
        } for i in range(len(conversations))
    ]
    return inputs



def main():
    llm = LLM(
        model=MODEL_PATH, trust_remote_code=True, gpu_memory_utilization=0.98,
        enforce_eager=True,     # Disable CUDA graphs; run an eager forward pass at every decode step.
        limit_mm_per_prompt={"audio": 5},
    )
    sampling_params = SamplingParams(
        temperature=0.7, top_p=0.01, top_k=1, repetition_penalty=1.1, max_tokens=256,
        stop_token_ids=[],
    )

    inputs = qwen2_audio_batch()
    print(f"{inputs=}")

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for i, output in enumerate(outputs):
        generated_text = output.outputs[0].text
        print()
        print('=' * 40)
        print(f"Inputs[{i}]: {inputs[i]['prompt']!r}")
        print(f"Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337 (Member)

Thanks for implementing this! Please run format.sh to fix the lint errors.

@faychu (Contributor, Author) commented Oct 11, 2024

Thanks for implementing this! Please run format.sh to fix the lint errors.

@DarkLight1337 Already fixed them.

@DarkLight1337 (Member) commented Oct 11, 2024

To ease testing, can you also update BaseMultiModalItemTracker._placeholder_str inside vllm/entrypoints/chat_utils.py to support online inference?

Also, it would be great if you could include this model in examples/offline_inference_audio_language.py!
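
For reference, here is a minimal sketch of what such a placeholder hook could look like. The method name comes from the comment above; the simplified class and the placeholder string (based on Qwen2-Audio's chat template) are assumptions for illustration, not necessarily the code that was merged.

# Hypothetical sketch of the placeholder hook mentioned above.
class BaseMultiModalItemTracker:  # simplified stand-in for the real class

    def __init__(self, model_type: str):
        self._model_type = model_type

    def _placeholder_str(self, modality: str, current_count: int) -> str:
        if modality == "audio" and self._model_type == "qwen2_audio":
            # Qwen2-Audio numbers each audio clip and wraps it in BOS/EOS tokens.
            return (f"Audio {current_count}: "
                    f"<|audio_bos|><|AUDIO|><|audio_eos|>")
        raise TypeError(f"Unknown modality: {modality}")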

@DarkLight1337 DarkLight1337 self-assigned this Oct 11, 2024
@faychu (Contributor, Author) commented Oct 12, 2024

To ease testing, can you also update BaseMultiModalItemTracker._placeholder_str inside vllm/entrypoints/chat_utils.py to support online inference?

Also, it would be great if you could include this model in examples/offline_inference_audio_language.py!

@DarkLight1337 Hi, I've updated vllm/entrypoints/chat_utils.py and examples/offline_inference_audio_language.py. Could you please take another look?

@DarkLight1337 (Member) commented Oct 12, 2024

I'm unable to run the example script because of an error in input_processor_for_qwen2_audio. Since multi_modal_data["audio"] is actually a tuple (audio_data, sampling_rate), we should pass them separately into processor.feature_extractor. Additional handling is required if you intend the model to support more than one audio input, since in that case it can either be a single tuple or a list of tuples.

By the way, you should also add this model to the Supported Models page of the docs.
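
For illustration, a minimal sketch of the normalization described above; the function name and structure are assumptions, not the actual input_processor_for_qwen2_audio implementation.

from typing import Union

import numpy as np
from transformers import AutoProcessor


def normalize_audio_inputs(multi_modal_data: dict, processor: AutoProcessor):
    """Accept either a single (audio, sampling_rate) tuple or a list of them,
    and pass the waveforms and sampling rate separately to the feature extractor."""
    audio_input = multi_modal_data["audio"]
    if not isinstance(audio_input, list):
        audio_input = [audio_input]  # single tuple -> list of tuples

    waveforms = [np.asarray(audio) for audio, _ in audio_input]
    return processor.feature_extractor(
        waveforms,
        sampling_rate=processor.feature_extractor.sampling_rate,
        return_tensors="pt",
    )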

@zjy-git commented Oct 17, 2024

Using the OpenAI Audio API client to send a request to the Qwen2-Audio vLLM backend returns an error:
openai.BadRequestError: Error code: 400 - {'object': 'error', 'message': 'operands could not be broadcast together with remapped shapes [original->remapped]: (2,2) and requested shape (1,2)', 'type': 'BadRequestError', 'param': None, 'code': 400}
With a streaming request, the backend reports:
ERROR 10-17 07:33:08 serving_chat.py:637] error in chat completion stream generator: operands could not be broadcast together with remapped shapes [original->remapped]: (2,2) and requested shape (1,2)

@kratorado

excellent

@umie0128

@faychu For long audio inputs, the length of input.tokens is wrong.

@faychu (Contributor, Author) commented Oct 28, 2024

@faychu For long audio inputs, the length of input.tokens is wrong.

@umie0128 Qwen2-Audio only supports audio lengths of up to 30 seconds.

cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Shanshan Wang <shanshan.wang@h2o.ai>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: qishuai <ferdinandzhong@gmail.com>
@seetimee mentioned this pull request Oct 30, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
@jlia0 commented Nov 5, 2024

@faychu For long audio inputs, the length of input.tokens is wrong.

@umie0128 Qwen2-Audio only supports audio lengths of up to 30 seconds.

What do you recommend we do if the audio exceeds 30 seconds?
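
One possible workaround, sketched below as an illustration rather than anything this PR provides, is to split the waveform into chunks of at most 30 seconds and pass each chunk as a separate audio item or request.

import numpy as np


def split_audio(audio: np.ndarray, sampling_rate: int,
                max_seconds: float = 30.0) -> list:
    """Split a waveform into consecutive chunks no longer than max_seconds."""
    chunk_len = int(max_seconds * sampling_rate)
    return [audio[i:i + chunk_len] for i in range(0, len(audio), chunk_len)]


# Each (chunk, sampling_rate) tuple can then be passed as its own audio item
# in multi_modal_data, or sent in a separate request.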

@MenggeLiu commented Nov 7, 2024

Hello, I use vllm.entrypoints.openai.api_server to start a Qwen2-Audio server and try to run inference with a .wav file, but I hit an error that crashes vLLM.
[error screenshot]

I also tried the example audio files above and there was no error.
Please help me check this problem.

@DarkLight1337 (Member)

Hello, I use vllm.entrypoints.openai.api_server to start a Qwen2-Audio server and try to run inference with a .wav file, but I hit an error that crashes vLLM.
[error screenshot]

I also tried the example audio files above and there was no error.
Please help me check this problem.

Can you share example code that results in this error?

@MenggeLiu commented Nov 7, 2024

Hello, I use vllm.entrypoints.openai.api_server to start a Qwen2-Audio server and try to run inference with a .wav file, but I hit an error that crashes vLLM.
[error screenshot]
I also tried the example audio files above and there was no error.
Please help me check this problem.

Can you share example code that results in this error?

  • server start code
CUDA_VISIBLE_DEVICES=4,5 \
python -m vllm.entrypoints.openai.api_server \
    --model qwen2-audio-simul \
    --gpu-memory-utilization 0.85 \
    --tensor-parallel-size 2 \
    --max-model-len 8192 \
    --port 8001
import json
import os
import requests
import soundfile as sf
import time
import wave
import base64

def encode_base64_content_from_file(file_path: str) -> str:
    """Encode content retrieved from a local file to base64 format."""

    with open(file_path, 'rb') as file:
        result = base64.b64encode(file.read()).decode('utf-8')

    return result

def encode_base64_content_from_url(content_url: str) -> str:
    """Encode a content retrieved from a remote url to base64 format."""

    with requests.get(content_url) as response:
        response.raise_for_status()
        result = base64.b64encode(response.content).decode('utf-8')

    return result

def query_qwen2audio_local(messages, temperature=0.9, max_retries=3, model="qwen2.5-14b-instruct_sft_simultrans_augment_1029-100w", port=8001):
    model = 'qwen2-audio-simul'
    url = 'http://localhost:8001/v1/chat/completions'

    headers = {'Content-Type': 'application/json'}

    data_entry = {
        "model": model, #"Qwen2-72B-Instruct-GPTQ-Int4",
        "messages": messages,
        "temperature": temperature,
        "max_tokens": 512,
    }
    retries = 0
    while retries < max_retries: 
        try:
            response = requests.request("POST", url, headers=headers, json=data_entry)
            print(response)
            response_text = json.loads(response.text)
            print(response_text)
            res = response_text['choices'][0]['message']['content']
            # print(res)
            return res, response_text
        except Exception as e:
            print('api error:{}'.format(e))
            time.sleep(1)  
            retries += 1  
    return 'Error', None

# read audio and encode to base64
audio_file = 'ted_1096_11.wav'
audio_file = 'ted_1096_7.wav'
audio, sr = sf.read(audio_file)

audio_base64 = encode_base64_content_from_file(audio_file)

messages=[{
    "role":
    "user",
    "content": [
        {
            "type": "text",
            "text": "What's in this audio?"
        },
        {
            "type": "audio_url",
            "audio_url": {
                "url": f"data:audio/ogg;base64,{audio_base64}"
            },
        },
    ],
}]

res, data = query_qwen2audio_local(messages, temperature=0.2, max_retries=1)

I just tried more audio files, and this error only happens on one audio file, but I couldn't find anything special about that audio.
I have shared the audio via Google Drive: https://drive.google.com/file/d/1HQ5zC5MoqzFcm3CG0gcAigOWMtmg8lBY/view?usp=drive_link

@DarkLight1337 (Member)

@faychu can you take a look at this?

@burness commented Nov 12, 2024

I built the latest vLLM, followed #9248 (comment), and got this error:
[error screenshot]

It seems some ops in MoE cause the error; can anyone help me?

[error screenshot]

@DarkLight1337 (Member)

The latest vLLM uses torch 2.5.1; please update your vLLM dependencies by installing from requirements-*.txt again.

@burness commented Nov 13, 2024

@DarkLight1337 Thanks. I will try updating torch to 2.5.1.

sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
@jiahansu

I downloaded v0.6.4.post1.

I'm encountering an error when trying to use vLLM serve with the Qwen/Qwen2-Audio-7B-Instruct model to process audio input.
Run the following curl command:

curl https://huxtwsgqgqkueq-5000.proxy.runpod.net/v1/chat/completions \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "model": "Qwen/Qwen2-Audio-7B-Instruct",
        "max_tokens": 1024,
        "temperature": 0.1,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio_url",
                        "audio_url": {
                            "url": "http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav"
                        }
                    },
                    {
                        "type": "text",
                        "text": "Transcribe Text"
                    }
                ]
            }
        ]
    }'

Observe the error output:

  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/errors.py", line 187, in __call__
    raise exc
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/base.py", line 185, in __call__
    with collapse_excgroups():
         ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/contextlib.py", line 158, in __exit__
    self.gen.throw(value)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/_utils.py", line 82, in collapse_excgroups
    raise exc
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/base.py", line 187, in __call__
    response = await self.dispatch_func(request, call_next)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 490, in add_request_id
    response = await call_next(request)
               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/base.py", line 163, in call_next
    raise app_exc
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/base.py", line 149, in coro
    await self.app(scope, receive_or_disconnect, send_no_error)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    raise exc
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/_exception_handler.py", line 42, in wrapped_app
    await app(scope, receive, sender)
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
               ^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 347, in create_chat_completion
    generator = await handler.create_chat_completion(request, raw_request)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/serving_chat.py", line 238, in create_chat_completion
    return await self.chat_completion_full_generator(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/serving_chat.py", line 598, in chat_completion_full_generator
    async for res in result_generator:
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/utils.py", line 402, in iterate_with_cancellation
    item = await awaits[0]
           ^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/client.py", line 633, in _process_request
    raise request_output
KeyError: 'prompt'

@DarkLight1337 (Member) commented Nov 20, 2024

Refer to #10493

tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
@RonanKMcGovern (Contributor)

Thanks for adding this model.

I'm running now via docker with 0.6.4.post1 and getting this error:

ValueError: Model architectures ['Qwen2AudioForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

I'm passing these arguments:

--served-model-name Qwen2-Audio-7B-Instruct --model Qwen/Qwen2-Audio-7B-Instruct --port 8000 --trust_remote_code --gpu_memory_utilization 0.98 --enforce_eager --limit_mm_per_prompt audio=5

I also ssh'd into the pod and can confirm that transformers 4.46.2 and vllm 0.6.4.post1+cu124 are installed (I'm running on runpod on an A40). One click template is here.

@DarkLight1337 (Member)

Thanks for adding this model.

I'm running now via docker with 0.6.4.post1 and getting this error:

ValueError: Model architectures ['Qwen2AudioForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

I'm passing these arguments:

--served-model-name Qwen2-Audio-7B-Instruct --model Qwen/Qwen2-Audio-7B-Instruct --port 8000 --trust_remote_code --gpu_memory_utilization 0.98 --enforce_eager --limit_mm_per_prompt audio=5

I also ssh'd into the pod and can confirm that transformers 4.46.2 and vllm 0.6.4.post1+cu124 are installed (I'm running on runpod on an A40). One click template is here.

Can you post the error log?

@RonanKMcGovern (Contributor)

Absolutely, I should have done that before:

2024-11-26T02:03:08.199442560Z INFO 11-25 18:03:08 api_server.py:585] vLLM API server version 0.6.4.post1
2024-11-26T02:03:08.199510077Z INFO 11-25 18:03:08 api_server.py:586] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='Qwen/Qwen2-Audio-7B-Instruct', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', chat_template_text_format='string', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.98, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=True, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt={'audio': 5}, mm_processor_kwargs=None, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['Qwen2-Audio-7B-Instruct'], qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False)
2024-11-26T02:03:08.211350315Z INFO 11-25 18:03:08 api_server.py:175] Multiprocessing frontend to use ipc:///tmp/09d6a08e-a33c-4c4e-bb4f-18434bccfa0c for IPC Path.
2024-11-26T02:03:08.214106669Z INFO 11-25 18:03:08 api_server.py:194] Started engine process with PID 88
2024-11-26T02:03:13.801792312Z ERROR 11-25 18:03:13 registry.py:297] Error in inspecting model architecture 'Qwen2AudioForConditionalGeneration'
2024-11-26T02:03:13.801814857Z ERROR 11-25 18:03:13 registry.py:297] Traceback (most recent call last):
2024-11-26T02:03:13.801816904Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 468, in _run_in_subprocess
2024-11-26T02:03:13.801818816Z ERROR 11-25 18:03:13 registry.py:297]     returned.check_returncode()
2024-11-26T02:03:13.801820667Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/lib/python3.12/subprocess.py", line 502, in check_returncode
2024-11-26T02:03:13.801834544Z ERROR 11-25 18:03:13 registry.py:297]     raise CalledProcessError(self.returncode, self.args, self.stdout,
2024-11-26T02:03:13.801836242Z ERROR 11-25 18:03:13 registry.py:297] subprocess.CalledProcessError: Command '['/usr/bin/python3', '-m', 'vllm.model_executor.models.registry']' returned non-zero exit status 1.
2024-11-26T02:03:13.801838157Z ERROR 11-25 18:03:13 registry.py:297]
2024-11-26T02:03:13.801840073Z ERROR 11-25 18:03:13 registry.py:297] The above exception was the direct cause of the following exception:
2024-11-26T02:03:13.801841407Z ERROR 11-25 18:03:13 registry.py:297]
2024-11-26T02:03:13.801842706Z ERROR 11-25 18:03:13 registry.py:297] Traceback (most recent call last):
2024-11-26T02:03:13.801844001Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 295, in _try_inspect_model_cls
2024-11-26T02:03:13.801846228Z ERROR 11-25 18:03:13 registry.py:297]     return model.inspect_model_cls()
2024-11-26T02:03:13.801847586Z ERROR 11-25 18:03:13 registry.py:297]            ^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.801848933Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 257, in inspect_model_cls
2024-11-26T02:03:13.801850283Z ERROR 11-25 18:03:13 registry.py:297]     return _run_in_subprocess(
2024-11-26T02:03:13.801851828Z ERROR 11-25 18:03:13 registry.py:297]            ^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.801853080Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 471, in _run_in_subprocess
2024-11-26T02:03:13.801854419Z ERROR 11-25 18:03:13 registry.py:297]     raise RuntimeError(f"Error raised in subprocess:\n"
2024-11-26T02:03:13.801856200Z ERROR 11-25 18:03:13 registry.py:297] RuntimeError: Error raised in subprocess:
2024-11-26T02:03:13.801859098Z ERROR 11-25 18:03:13 registry.py:297] <frozen runpy>:128: RuntimeWarning: 'vllm.model_executor.models.registry' found in sys.modules after import of package 'vllm.model_executor.models', but prior to execution of 'vllm.model_executor.models.registry'; this may result in unpredictable behaviour
2024-11-26T02:03:13.801860909Z ERROR 11-25 18:03:13 registry.py:297] Traceback (most recent call last):
2024-11-26T02:03:13.801862260Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen runpy>", line 198, in _run_module_as_main
2024-11-26T02:03:13.801863575Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen runpy>", line 88, in _run_code
2024-11-26T02:03:13.801864866Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 492, in <module>
2024-11-26T02:03:13.801866222Z ERROR 11-25 18:03:13 registry.py:297]     _run()
2024-11-26T02:03:13.801867740Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 485, in _run
2024-11-26T02:03:13.801869090Z ERROR 11-25 18:03:13 registry.py:297]     result = fn()
2024-11-26T02:03:13.801870440Z ERROR 11-25 18:03:13 registry.py:297]              ^^^^
2024-11-26T02:03:13.801871966Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 258, in <lambda>
2024-11-26T02:03:13.801873381Z ERROR 11-25 18:03:13 registry.py:297]     lambda: _ModelInfo.from_model_cls(self.load_model_cls()))
2024-11-26T02:03:13.801874728Z ERROR 11-25 18:03:13 registry.py:297]                                       ^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.801876024Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 261, in load_model_cls
2024-11-26T02:03:13.801877358Z ERROR 11-25 18:03:13 registry.py:297]     mod = importlib.import_module(self.module_name)
2024-11-26T02:03:13.801878670Z ERROR 11-25 18:03:13 registry.py:297]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.801882298Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/lib/python3.12/importlib/__init__.py", line 90, in import_module
2024-11-26T02:03:13.801883737Z ERROR 11-25 18:03:13 registry.py:297]     return _bootstrap._gcd_import(name[level:], package, level)
2024-11-26T02:03:13.801885085Z ERROR 11-25 18:03:13 registry.py:297]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.801886640Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap>", line 1387, in _gcd_import
2024-11-26T02:03:13.801887994Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap>", line 1360, in _find_and_load
2024-11-26T02:03:13.801889327Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
2024-11-26T02:03:13.801890632Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
2024-11-26T02:03:13.801891890Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap_external>", line 995, in exec_module
2024-11-26T02:03:13.801893146Z ERROR 11-25 18:03:13 registry.py:297]   File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
2024-11-26T02:03:13.801894753Z ERROR 11-25 18:03:13 registry.py:297]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2_audio.py", line 25, in <module>
2024-11-26T02:03:13.801896047Z ERROR 11-25 18:03:13 registry.py:297]     import librosa
2024-11-26T02:03:13.801897366Z ERROR 11-25 18:03:13 registry.py:297] ModuleNotFoundError: No module named 'librosa'
2024-11-26T02:03:13.801898705Z ERROR 11-25 18:03:13 registry.py:297]
2024-11-26T02:03:13.802900535Z Traceback (most recent call last):
2024-11-26T02:03:13.802928051Z   File "<frozen runpy>", line 198, in _run_module_as_main
2024-11-26T02:03:13.802930903Z   File "<frozen runpy>", line 88, in _run_code
2024-11-26T02:03:13.802935211Z   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 643, in <module>
2024-11-26T02:03:13.803389021Z     uvloop.run(run_server(args))
2024-11-26T02:03:13.803505330Z   File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
2024-11-26T02:03:13.803662829Z     return __asyncio.run(
2024-11-26T02:03:13.803680602Z            ^^^^^^^^^^^^^^
2024-11-26T02:03:13.803685422Z   File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
2024-11-26T02:03:13.803889826Z     return runner.run(main)
2024-11-26T02:03:13.803916153Z            ^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.803922972Z   File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
2024-11-26T02:03:13.804056221Z     return self._loop.run_until_complete(task)
2024-11-26T02:03:13.804117092Z            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.804121927Z   File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
2024-11-26T02:03:13.804448618Z   File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
2024-11-26T02:03:13.804598455Z     return await main
2024-11-26T02:03:13.804627003Z            ^^^^^^^^^^
2024-11-26T02:03:13.804633607Z   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 609, in run_server
2024-11-26T02:03:13.804954262Z     async with build_async_engine_client(args) as engine_client:
2024-11-26T02:03:13.805001853Z                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.805006631Z   File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
2024-11-26T02:03:13.805177435Z     return await anext(self.gen)
2024-11-26T02:03:13.805217051Z            ^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.805221209Z   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 113, in build_async_engine_client
2024-11-26T02:03:13.805346343Z     async with build_async_engine_client_from_engine_args(
2024-11-26T02:03:13.805373593Z                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.805375808Z   File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
2024-11-26T02:03:13.805529891Z     return await anext(self.gen)
2024-11-26T02:03:13.805572662Z            ^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.805587862Z   File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 197, in build_async_engine_client_from_engine_args
2024-11-26T02:03:13.805750639Z     engine_config = engine_args.create_engine_config()
2024-11-26T02:03:13.805800568Z                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.805804276Z   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/arg_utils.py", line 959, in create_engine_config
2024-11-26T02:03:13.806266640Z     model_config = self.create_model_config()
2024-11-26T02:03:13.806316838Z                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.806318928Z   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/arg_utils.py", line 891, in create_model_config
2024-11-26T02:03:13.806746030Z     return ModelConfig(
2024-11-26T02:03:13.806763144Z            ^^^^^^^^^^^^
2024-11-26T02:03:13.806787097Z   File "/usr/local/lib/python3.12/dist-packages/vllm/config.py", line 251, in __init__
2024-11-26T02:03:13.806968373Z     self.multimodal_config = self._init_multimodal_config(
2024-11-26T02:03:13.807005569Z                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.807007124Z   File "/usr/local/lib/python3.12/dist-packages/vllm/config.py", line 277, in _init_multimodal_config
2024-11-26T02:03:13.807185421Z     if ModelRegistry.is_multimodal_model(architectures):
2024-11-26T02:03:13.807242161Z        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.807243666Z   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 422, in is_multimodal_model
2024-11-26T02:03:13.807525542Z     return self.inspect_model_cls(architectures).supports_multimodal
2024-11-26T02:03:13.807577860Z            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.807579473Z   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 391, in inspect_model_cls
2024-11-26T02:03:13.807804010Z     return self._raise_for_unsupported(architectures)
2024-11-26T02:03:13.807859237Z            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2024-11-26T02:03:13.807860736Z   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/registry.py", line 348, in _raise_for_unsupported
2024-11-26T02:03:13.808059922Z     raise ValueError(
2024-11-26T02:03:13.808074090Z ValueError: Model architectures ['Qwen2AudioForConditionalGeneration'] failed to be inspected. Please check the logs for more details.

@DarkLight1337 (Member)

You need to install librosa. You can do so by installing it directly or via pip install vllm[audio].

@RonanKMcGovern (Contributor)

Ah, yes, thanks.

Is there a way to get that into the OpenAI-compatible Docker image, or is the easiest/best thing for me to just build a Docker image that wraps the current one and installs librosa? Thanks

@DarkLight1337 (Member)

We don't include it inside our core dependencies because of licensing issues. It would be best for you to create your own docker image for this purpose.

@RonanKMcGovern (Contributor)

Noted, with thanks

Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: None yet
10 participants