
Commit e42be96: zlib

robertgshaw2-redhat committed Jul 28, 2024
1 parent 852534e
Showing 2 changed files with 10 additions and 19 deletions.
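
In short, this commit wraps the existing pickle serialization in zlib compression on both sides of the ZeroMQ path between vllm/grpc/client.py and vllm/grpc/server.py, and enables chunked prefill when the server builds its engine. Below is a minimal sketch of the compress/decompress round trip the diff introduces; the helper names to_wire and from_wire are illustrative and are not part of the commit.

import pickle
import zlib

def to_wire(obj) -> bytes:
    # Serialize first, then compress the pickled bytes.
    return zlib.compress(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))

def from_wire(payload: bytes):
    # Inverse order on receipt: decompress, then unpickle.
    return pickle.loads(zlib.decompress(payload))

# Round trip: the payload that crosses the ZeroMQ socket is the compressed form.
assert from_wire(to_wire({"request_id": "abc"})) == {"request_id": "abc"}

Note the ordering: compression is applied to the already-pickled bytes, so the receiver must decompress before unpickling.
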
vllm/grpc/client.py: 5 additions & 15 deletions
@@ -1,22 +1,14 @@
import asyncio
from vllm import AsyncLLMEngine
import grpc

# from vllm.grpc.server import UNIX_SOCKET
from .pb import generate_pb2_grpc, generate_pb2
-from typing import AsyncIterator, List, Optional, Mapping
+from typing import AsyncIterator, Optional, Mapping

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import CompletionOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer
from dataclasses import dataclass

import time
-import zmq
+import zmq, zlib
import zmq.asyncio
import pickle

@@ -35,9 +27,7 @@ def __init__(self):
self.worker_use_ray = False
self.log_requests = False
self.engine = None

-self.tokenizer = AutoTokenizer.from_pretrained(MODEL)

self.context = zmq.asyncio.Context()


@@ -77,18 +67,18 @@ async def generate(
socket.connect('tcp://localhost:5570')

await socket.send_multipart([
-pickle.dumps(
+zlib.compress(pickle.dumps(
RCPRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id
), pickle.HIGHEST_PROTOCOL
-)
+))
])

while True:
message = await socket.recv()
-request_output = pickle.loads(message)
+request_output = pickle.loads(zlib.decompress(message))

if request_output.finished:
break
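
For context, here is a self-contained sketch of the client-side flow after this change. The diff does not show the socket type or teardown, so the DEALER socket, the coroutine name stream_request, and the cleanup below are assumptions; the address, framing, and compression mirror the diff.

import pickle
import zlib
import zmq
import zmq.asyncio

async def stream_request(request_obj):
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.DEALER)
    socket.connect("tcp://localhost:5570")
    try:
        # Compress the pickled request and send it as a single frame.
        await socket.send_multipart(
            [zlib.compress(pickle.dumps(request_obj, pickle.HIGHEST_PROTOCOL))])

        # Each reply frame is a compressed, pickled engine output.
        while True:
            message = await socket.recv()
            request_output = pickle.loads(zlib.decompress(message))
            yield request_output
            if request_output.finished:
                break
    finally:
        socket.close()
        context.term()

A caller would consume it with "async for output in stream_request(req): ...".
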
vllm/grpc/server.py: 5 additions & 4 deletions
@@ -1,7 +1,7 @@
from vllm import AsyncEngineArgs, AsyncLLMEngine
import asyncio
import pickle
-import zmq
+import zmq, zlib
import zmq.asyncio

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -15,10 +15,10 @@ def __init__(self):

self.running_tasks = set()
self.engine = AsyncLLMEngine.from_engine_args(
-AsyncEngineArgs(model=MODEL))
+AsyncEngineArgs(model=MODEL, enable_chunked_prefill=True))

async def generate(self, identity, message):
-request = pickle.loads(message)
+request = pickle.loads(zlib.decompress(message))
results_generator = self.engine.generate(
request.inputs,
sampling_params=request.sampling_params,
@@ -27,7 +27,8 @@ async def generate(self, identity, message):
async for request_output in results_generator:
self.socket.send_multipart([
identity,
-pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
+zlib.compress(
+pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL))
])

async def run_loop(self):
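
And the receiving side, sketched under the assumption that the server binds a ROUTER socket (which is why each reply is prefixed with the client's identity frame). The serial accept loop, the serve function, and the request.request_id field are simplifications and assumptions; the real server dispatches each request as a task and builds its engine with AsyncEngineArgs(model=MODEL, enable_chunked_prefill=True) as shown in the diff above.

import pickle
import zlib
import zmq
import zmq.asyncio

async def serve(engine, addr: str = "tcp://*:5570"):
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.ROUTER)
    socket.bind(addr)
    while True:
        # Frame 0 is the DEALER identity, frame 1 the compressed request.
        identity, payload = await socket.recv_multipart()
        request = pickle.loads(zlib.decompress(payload))

        results_generator = engine.generate(
            request.inputs,
            sampling_params=request.sampling_params,
            request_id=request.request_id)

        # Stream every output back, compressed the same way it arrived.
        async for request_output in results_generator:
            await socket.send_multipart([
                identity,
                zlib.compress(
                    pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)),
            ])
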
