This repository has been archived by the owner on Feb 15, 2025. It is now read-only.

Commit

Merge pull request #169 from defenseunicorns/stablelm-streamcomplete-cleanup

Bug Fix: Stablelm and ctransformers
Gerred Dillon authored Aug 31, 2023
2 parents 1568345 + 5498cb7 commit bc978ac
Showing 5 changed files with 29 additions and 25 deletions.
3 changes: 2 additions & 1 deletion models/llms/ctransformers/requirements.txt
@@ -1,2 +1,3 @@
transformers
torch
torch
ctransformers
2 changes: 1 addition & 1 deletion models/llms/ctransformers/test.py
@@ -14,7 +14,7 @@
You are an AI assistant that answers participates in chat discussions in an honest, concise, friendly way.<|im_end|>
<|im_start|>user
Write two sequences composed of 3 'A's and 2 'B's such that there are no two successive identical letter. Be concise.<|im_end|>
<|im_assistant|>
<|im_start|>assistant
"""

def run():
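The one-line fix in this file replaces the <|im_assistant|> tag with <|im_start|>assistant, bringing the prompt in line with the ChatML convention the rest of it already uses: each turn opens with <|im_start|>role and closes with <|im_end|>, and the assistant turn is left open so the model completes it. A small illustrative helper (not part of this repo) that assembles such a prompt:

# Illustrative only: build a ChatML-style prompt like the one used in test.py.
def build_chatml_prompt(messages):
    parts = [
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>" for m in messages
    ]
    # Leave the assistant turn open so the model generates the reply after it.
    parts.append("<|im_start|>assistant")
    return "\n".join(parts)

prompt = build_chatml_prompt([
    {"role": "system", "content": "You are a concise, friendly assistant."},
    {"role": "user", "content": "Write two sequences of 3 'A's and 2 'B's with no two identical letters in a row."},
])
print(prompt)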
4 changes: 2 additions & 2 deletions models/llms/ctransformers/test_stream.py
@@ -14,7 +14,7 @@
You are an AI assistant that answers participates in chat discussions in an honest, concise, friendly way.<|im_end|>
<|im_start|>user
Write two sequences composed of 3 'A's and 2 'B's such that there are no two successive identical letter. Be concise.<|im_end|>
<|im_assistant|>
<|im_start|>assistant
"""

def run():
@@ -35,7 +35,7 @@ def run():
response: Iterator[leapfrogai.CompletionResponse] = stub.CompleteStream(request)

for completion in response:
print(completion.choices[0].text, end="")
print(completion.choices[0].text, end="", flush=True)


if __name__ == "__main__":
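Both test clients also add flush=True to the print call that writes streamed tokens. With end="" suppressing newlines, stdout (typically line-buffered at a terminal) holds the text until a newline or a full buffer arrives, so the stream appears to stall and then dump all at once; flushing after each chunk makes tokens show up as they arrive. A standalone illustration, not tied to this repo:

import time

def fake_stream():
    # Stand-in for an iterator of streamed completion chunks.
    for token in ["Stream", "ing ", "tok", "ens ", "one ", "at ", "a ", "time."]:
        time.sleep(0.2)
        yield token

for token in fake_stream():
    # Without flush=True the buffered text may only appear after a newline,
    # defeating the point of streaming the completion.
    print(token, end="", flush=True)
print()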
43 changes: 23 additions & 20 deletions models/llms/stablelm/stablelm.py
@@ -17,6 +17,7 @@
CompletionRequest,
CompletionResponse,
CompletionServiceServicer,
CompletionStreamServiceServicer,
GrpcContext,
serve,
)
@@ -32,49 +33,51 @@ def __call__(
return True
return False


class StableLM(CompletionServiceServicer):
torch.cuda.init()
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
model.half().cuda()
print("StableLM Loaded...")
class StableLM(CompletionServiceServicer, CompletionStreamServiceServicer):
def __init__(self):
torch.cuda.init()
if torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
self.model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
self.model.half().cuda()
print("StableLM Loaded...")

def Complete(
self, request: CompletionRequest, context: GrpcContext
) -> CompletionResponse:
print(f"Request: { request }")
inputs = self.tokenizer(request.prompt, return_tensors="pt").to(torch.cuda.current_device())
logging.debug(f"Request: { request }")
inputs = self.tokenizer(request.prompt, return_tensors="pt").to(self.device)

# error checking for valid params
tokens = self.model.generate(
**inputs,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
# repetition_penalty=request.frequence_penalty,
# top_p=request.top_p,
do_sample=True,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)
logging.debug(f"Response {tokens}")

# Extract out only the completion tokens
completion_tokens = tokens[0][inputs["input_ids"].size(1) :]
completion = self.tokenizer.decode(completion_tokens, skip_special_tokens=True)

c = CompletionChoice(text=completion, index=0)
logging.debug(f"Decoded Response: {completion}")

return CompletionResponse(choices=[c])


def CompleteStream(self, request: CompletionRequest, context: GrpcContext):
inputs = self.tokenizer(request.prompt, return_tensors="pt").to(self.device)
logging.debug(f"Request: { request }")
inputs = self.tokenizer(request.prompt, return_tensors="pt").to(self.device)

streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)

generation_kwargs = dict(
inputs,
streamer=streamer,
@@ -86,12 +89,12 @@ def CompleteStream(self, request: CompletionRequest, context: GrpcContext):
eos_token_id=self.tokenizer.eos_token_id,
stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
)

thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
for text in streamer:
print(text)
yield text
# logging.debug(f"Response {tokens}")
completion = CompletionChoice(text=text, index=0)
yield CompletionResponse(choices=[completion])


if __name__ == "__main__":
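The stablelm.py change moves model and tokenizer setup from the class body into __init__, registers the servicer for the streaming completion service, moves inputs onto the selected device, and has CompleteStream yield structured CompletionResponse messages instead of raw strings. The generation side follows the standard Hugging Face TextIteratorStreamer pattern: generate() blocks, so it runs in a background thread while the handler consumes decoded text from the streamer. Below is a minimal, self-contained sketch of that pattern, with a placeholder model name and SimpleNamespace standing in for the leapfrogai response types:

# Illustrative sketch, not the leapfrogai implementation: stream tokens from a
# Hugging Face causal LM by running generate() in a background thread and
# consuming a TextIteratorStreamer in the request handler.
from threading import Thread
from types import SimpleNamespace

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "gpt2"  # placeholder; the real servicer loads the StableLM weights
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)

def complete_stream(prompt: str, max_new_tokens: int = 64):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # skip_prompt=True keeps the echoed prompt out of the streamed completion
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks until completion, so run it in a worker thread and
    # yield each decoded chunk as it becomes available
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for text in streamer:
        # wrap each chunk in a structured response, mirroring the fix above
        yield SimpleNamespace(choices=[SimpleNamespace(text=text, index=0)])
    thread.join()

if __name__ == "__main__":
    for chunk in complete_stream("Hello, world"):
        print(chunk.choices[0].text, end="", flush=True)
    print()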
2 changes: 1 addition & 1 deletion models/llms/stablelm/test_stream.py
@@ -37,7 +37,7 @@ def run():
response: Iterator[leapfrogai.CompletionResponse] = stub.CompleteStream(request)

for completion in response:
print(completion.choices[0].text, end="")
print(completion.choices[0].text, end="", flush=True)


if __name__ == "__main__":
