Skip to content

Commit

Permalink
Remove past_key_value blocked by transformers cache error huggingface…
Browse files Browse the repository at this point in the history
  • Loading branch information
haixuanTao committed Feb 4, 2025
1 parent 071de27 commit 4d83d10
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
IMAGE_WIDTH = int(
os.getenv(
"IMAGE_WIDTH",
"320",
"0",
)
)
IMAGE_HEIGHT = int(
os.getenv(
"IMAGE_HEIGHT",
"225",
"0",
)
)
HISTORY = os.getenv("HISTORY", "False") in ["True", "true"]
Expand Down Expand Up @@ -104,10 +104,12 @@ def generate(frames: dict, question, history, past_key_values=None):
inputs = inputs.to(model.device)

# Inference: Generation of the output
## TODO: Add past_key_values to the inputs when https://github.com/huggingface/transformers/issues/34678 is fixed.
outputs = model.generate(
**inputs, max_new_tokens=128, past_key_values=past_key_values
**inputs,
max_new_tokens=128, # past_key_values=past_key_values
)
past_key_values = outputs.past_key_values
# past_key_values = outputs.past_key_values

generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, outputs)
Expand Down Expand Up @@ -201,7 +203,10 @@ def main():
else:
raise RuntimeError(f"Unsupported image encoding: {encoding}")
image = Image.fromarray(frame)
frames[event_id] = image.resize((IMAGE_HEIGHT, IMAGE_WIDTH))
if IMAGE_HEIGHT > 0 and IMAGE_WIDTH > 0:
frames[event_id] = image.resize((IMAGE_HEIGHT, IMAGE_WIDTH))
else:
frames[event_id] = image

elif "text" in event_id:
if len(event["value"]) > 0:
Expand Down

0 comments on commit 4d83d10

Please sign in to comment.