Skip to content

Commit

Permalink
Revert "zlib"
Browse files Browse the repository at this point in the history
This reverts commit e42be96.
  • Loading branch information
robertgshaw2-redhat committed Jul 28, 2024
1 parent e42be96 commit 5202a59
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
20 changes: 15 additions & 5 deletions vllm/grpc/client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import asyncio
from vllm import AsyncLLMEngine
from typing import AsyncIterator, Optional, Mapping
import grpc

# from vllm.grpc.server import UNIX_SOCKET
from .pb import generate_pb2_grpc, generate_pb2
from typing import AsyncIterator, List, Optional, Mapping

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import CompletionOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer
from dataclasses import dataclass
import zmq, zlib

import time
import zmq
import zmq.asyncio
import pickle

Expand All @@ -27,7 +35,9 @@ def __init__(self):
self.worker_use_ray = False
self.log_requests = False
self.engine = None

self.tokenizer = AutoTokenizer.from_pretrained(MODEL)

self.context = zmq.asyncio.Context()


Expand Down Expand Up @@ -67,18 +77,18 @@ async def generate(
socket.connect('tcp://localhost:5570')

await socket.send_multipart([
zlib.compress(pickle.dumps(
pickle.dumps(
RCPRequest(
inputs=inputs,
sampling_params=sampling_params,
request_id=request_id
), pickle.HIGHEST_PROTOCOL
))
)
])

while True:
message = await socket.recv()
request_output = pickle.loads(zlib.decompress(message))
request_output = pickle.loads(message)

if request_output.finished:
break
Expand Down
9 changes: 4 additions & 5 deletions vllm/grpc/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from vllm import AsyncEngineArgs, AsyncLLMEngine
import asyncio
import pickle
import zmq, zlib
import zmq
import zmq.asyncio

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand All @@ -15,10 +15,10 @@ def __init__(self):

self.running_tasks = set()
self.engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(model=MODEL, enable_chunked_prefill=True))
AsyncEngineArgs(model=MODEL))

async def generate(self, identity, message):
request = pickle.loads(zlib.decompress(message))
request = pickle.loads(message)
results_generator = self.engine.generate(
request.inputs,
sampling_params=request.sampling_params,
Expand All @@ -27,8 +27,7 @@ async def generate(self, identity, message):
async for request_output in results_generator:
self.socket.send_multipart([
identity,
zlib.compress(
pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL))
pickle.dumps(request_output, pickle.HIGHEST_PROTOCOL)
])

async def run_loop(self):
Expand Down

0 comments on commit 5202a59

Please sign in to comment.