feat: Add Phi3 Mini Requirements & Set Active Adapter (#469)
**Reason for Change**:
- Add the preset requirements for the Phi-3 Mini models.
- Add the updated libraries needed for Phi-3.
- Fix a bug so that the combined weighted adapter is set as the active adapter during inference (see the sketch below).
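
For context, a minimal sketch of that adapter flow, using the same PEFT calls as the updated `inference_api.py`; the base model id, adapter paths, names, and weights below are illustrative placeholders:

```
from peft import PeftModel
from transformers import AutoModelForCausalLM

# Illustrative placeholders: in Kaito the adapters are mounted under /mnt/adapter
# and each adapter's weight is read from an environment variable of the same name.
base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
adapter_paths = ["/mnt/adapter/adapter-a", "/mnt/adapter/adapter-b"]
adapter_names = ["adapter-a", "adapter-b"]
weights = [0.7, 0.3]

# Load the first adapter, then attach the rest.
model = PeftModel.from_pretrained(base_model, adapter_paths[0], adapter_name=adapter_names[0])
for path, name in zip(adapter_paths[1:], adapter_names[1:]):
    model.load_adapter(path, adapter_name=name)

# Merge the adapters into a single weighted adapter and make it the active one;
# calling set_adapter is the part the bug fix adds.
model.add_weighted_adapter(
    adapters=adapter_names,
    weights=weights,
    adapter_name="combined_adapter",
    combination_type="linear",
)
model.set_adapter("combined_adapter")
```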
ishaansehgal99 authored Jun 13, 2024
1 parent d38d87b commit 3ef8e66
Showing 11 changed files with 292 additions and 48 deletions.
3 changes: 2 additions & 1 deletion cmd/models.go
@@ -7,5 +7,6 @@ import (
_ "github.com/azure/kaito/presets/models/llama2"
_ "github.com/azure/kaito/presets/models/llama2chat"
_ "github.com/azure/kaito/presets/models/mistral"
_ "github.com/azure/kaito/presets/models/phi"
_ "github.com/azure/kaito/presets/models/phi-2"
_ "github.com/azure/kaito/presets/models/phi-3"
)
2 changes: 1 addition & 1 deletion pkg/resources/manifests.go
@@ -272,7 +272,7 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
}
initContainers := []corev1.Container{}
envs := []corev1.EnvVar{}
if len(workspaceObj.Inference.Adapters) != 0 {
if len(workspaceObj.Inference.Adapters) > 0 {
for _, adapter := range workspaceObj.Inference.Adapters {
// TODO: accept Volumes and url link to pull images
initContainer := corev1.Container{
14 changes: 7 additions & 7 deletions presets/README.md
@@ -1,13 +1,13 @@
# Kaito Preset Configurations
The current supported model families with preset configurations are listed below.

| Model Family | Compatible Kaito Versions |
|-----------------|---------------------------|
|[falcon](./models/falcon)| v0.0.1+|
|[llama2](./models/llama2)| v0.0.1+|
|[llama2chat](./models/llama2chat)| v0.0.1+|
|[mistral](./models/mistral)| v0.2.0+|
|[phi2](./models/phi)| v0.2.0+|
| Model Family | Compatible Kaito Versions |
|-----------------------------------|---------------------------|
| [falcon](./models/falcon) | v0.0.1+|
| [llama2](./models/llama2) | v0.0.1+|
| [llama2chat](./models/llama2chat) | v0.0.1+|
| [mistral](./models/mistral) | v0.2.0+|
| [phi2](./models/phi-2) | v0.2.0+|

## Validation
Each preset model has its own hardware requirements in terms of GPU count and GPU memory defined in the respective `model.go` file. Kaito controller performs a validation check of whether the specified SKU and node count are sufficient to run the model or not. In case the provided SKU is not in the known list, the controller bypasses the validation check which means users need to ensure the model can run with the provided SKU.
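
As a rough illustration of that check (not the controller's actual code), the sketch below uses a hypothetical SKU table to decide whether a given SKU and node count can provide enough GPU memory for a model; the SKU entries and memory figures are examples only:

```
# Hypothetical SKU table; real values live in the controller's known-SKU list.
KNOWN_SKUS = {
    "Standard_NC24ads_A100_v4": {"gpu_count": 1, "gpu_memory_gb": 80},
    "Standard_NC6s_v3": {"gpu_count": 1, "gpu_memory_gb": 16},
}

def can_host(sku: str, node_count: int, required_gpu_memory_gb: int) -> bool:
    """Return True if the SKU/node combination offers enough total GPU memory."""
    spec = KNOWN_SKUS.get(sku)
    if spec is None:
        # Unknown SKU: validation is bypassed and the user must size the nodes.
        return True
    total_memory = spec["gpu_count"] * spec["gpu_memory_gb"] * node_count
    return total_memory >= required_gpu_memory_gb

print(can_host("Standard_NC6s_v3", node_count=1, required_gpu_memory_gb=14))  # True
```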
76 changes: 41 additions & 35 deletions presets/inference/text-generation/inference_api.py
@@ -1,21 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import subprocess
from dataclasses import asdict, dataclass, field
from typing import Annotated, Any, Dict, List, Optional

import GPUtil
import psutil
import torch
import transformers
import subprocess
import uvicorn
from fastapi import Body, FastAPI, HTTPException
from fastapi.responses import Response
from peft import PeftModel
from pydantic import BaseModel, Extra, Field, validator
from transformers import (AutoModelForCausalLM, AutoTokenizer,
GenerationConfig, HfArgumentParser)
from peft import PeftModel

ADAPTERS_DIR = '/mnt/adapter'
@dataclass
@@ -66,9 +66,12 @@ def __post_init__(self): # validate parameters
"""
Post-initialization to validate some ModelConfig values
"""
if self.torch_dtype and not hasattr(torch, self.torch_dtype):
if self.torch_dtype == "auto":
pass
elif self.torch_dtype and self.torch_dtype != "auto" and not hasattr(torch, self.torch_dtype):
raise ValueError(f"Invalid torch dtype: {self.torch_dtype}")
self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None
else:
self.torch_dtype = getattr(torch, self.torch_dtype) if self.torch_dtype else None

supported_pipelines = {"conversational", "text-generation"}
if self.pipeline not in supported_pipelines:
@@ -91,39 +94,42 @@ def __post_init__(self): # validate parameters
tokenizer = AutoTokenizer.from_pretrained(**model_args)
base_model = AutoModelForCausalLM.from_pretrained(**model_args)

def list_files(directory):
try:
result = subprocess.run(['ls', directory], capture_output=True, text=True)
if result.returncode == 0:
return result.stdout.strip().split('\n')
else:
return [f"Command execution failed with return code: {result.returncode}"]
except Exception as e:
return [f"An error occurred: {str(e)}"]
if not os.path.exists(ADAPTERS_DIR):
model = base_model
else:
output = os.listdir(ADAPTERS_DIR)
filtered_output = [s for s in output if s.strip()]
adapters_list = [f"{ADAPTERS_DIR}/{file}" for file in filtered_output]
filtered_adapters_list = [path for path in adapters_list if os.path.exists(os.path.join(path, "adapter_config.json"))]

adapter_names, weights= [], []
for adapter_path in filtered_adapters_list:
adapter_name = os.path.basename(adapter_path)
adapter_names.append(adapter_name)
weights.append(float(os.getenv(adapter_name)))
model = PeftModel.from_pretrained(base_model, filtered_adapters_list[0], adapter_name=adapter_names[0])
for i in range(1, len(filtered_adapters_list)):
model.load_adapter(filtered_adapters_list[i], adapter_names[i])

model.add_weighted_adapter(
adapters = adapter_names,
weights = weights,
adapter_name="combined_adapter",
combination_type=combination_type,
)
print("Model:",model)
else:
valid_adapters_list = [
os.path.join(ADAPTERS_DIR, adapter) for adapter in os.listdir(ADAPTERS_DIR)
if os.path.isfile(os.path.join(ADAPTERS_DIR, adapter, "adapter_config.json"))
]

if valid_adapters_list:
adapter_names, weights = [], []
for adapter_path in valid_adapters_list:
adapter_name = os.path.basename(adapter_path)
adapter_names.append(adapter_name)
weights.append(float(os.getenv(adapter_name, '1.0')))

model = PeftModel.from_pretrained(base_model, valid_adapters_list[0], adapter_name=adapter_names[0])
for i in range(1, len(valid_adapters_list)):
model.load_adapter(valid_adapters_list[i], adapter_name=adapter_names[i])

model.add_weighted_adapter(
adapters=adapter_names,
weights=weights,
adapter_name="combined_adapter",
combination_type=combination_type,
)

model.set_adapter("combined_adapter")

# To avoid any potential future operations that use non-combined adapters
for adapter in adapter_names:
model.delete_adapter(adapter)
else:
print("Warning: Did not find any valid adapters mounted, using base model")
model = base_model

print("Model:", model)

pipeline_kwargs = {
"trust_remote_code": args.trust_remote_code,
3 changes: 2 additions & 1 deletion presets/inference/text-generation/requirements.txt
@@ -1,12 +1,13 @@
# Dependencies for TFS

# Core Dependencies
transformers==4.40.2
transformers==4.41.2
torch==2.2.0
accelerate==0.30.1
fastapi>=0.111.0,<0.112.0 # Allow patch updates
pydantic>=2.7.1,<2.8 # Allow patch updates
uvicorn[standard]>=0.29.0,<0.30.0 # Allow patch updates
peft

# Utility libraries
bitsandbytes==0.42.0
File renamed without changes.
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package phi
package phi_2

import (
"time"
118 changes: 118 additions & 0 deletions presets/models/phi-3/README.md
@@ -0,0 +1,118 @@
## Supported Models
| Model name | Model source | Sample workspace|Kubernetes Workload|Distributed inference|
|--------------------------|:----:|:----:| :----: |:----: |
| phi-3-mini-4k-instruct |[microsoft](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false|
| phi-3-mini-128k-instruct |[microsoft](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false|


## Image Source
- **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR).

## Usage

Phi-3 Mini models are best suited for prompts using the chat format shown below. You can provide the prompt as a question using the following generic template:

```
<|user|>\nQuestion<|end|>\n<|assistant|>
```

For example:
```
<|user|>
How to explain Internet for a medieval knight?<|end|>
<|assistant|>
```

Or in the case of few shots prompt:
```
<|user|>
I am going to Paris, what should I see?<|end|>
<|assistant|>
Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."<|end|>
<|user|>
What is so great about #1?<|end|>
<|assistant|>
```

The inference service endpoint is `/chat`.


### Basic example
```
curl -X POST "http://<SERVICE>:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"YOUR_PROMPT_HERE"}'
```
```
curl -X POST "http://<SERVICE>:80/chat" -H "accept: application/json" -H "Content-Type: application/json" -d '{"prompt":"<|user|> How to explain Internet for a medieval knight?<|end|><|assistant|>"}'
```
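
The same request can be sent from Python; the sketch below (the service address is a placeholder) wraps a question in the chat template and posts it to `/chat`:
```
import requests

# Replace with the actual hostname or IP of the inference service.
SERVICE = "http://<SERVICE>:80"

question = "How to explain Internet for a medieval knight?"
prompt = f"<|user|>\n{question}<|end|>\n<|assistant|>"

response = requests.post(
    f"{SERVICE}/chat",
    headers={"accept": "application/json", "Content-Type": "application/json"},
    json={"prompt": prompt},
)
response.raise_for_status()
print(response.json())
```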


### Example with full configurable parameters
```
curl -X POST \
-H "accept: application/json" \
-H "Content-Type: application/json" \
-d '{
"prompt":"YOUR_PROMPT_HERE",
"return_full_text": false,
"clean_up_tokenization_spaces": false,
"prefix": null,
"handle_long_generation": null,
"generate_kwargs": {
"max_length":200,
"min_length":0,
"do_sample":true,
"early_stopping":false,
"num_beams":1,
"num_beam_groups":1,
"diversity_penalty":0.0,
"temperature":1.0,
"top_k":10,
"top_p":1,
"typical_p":1,
"repetition_penalty":1,
"length_penalty":1,
"no_repeat_ngram_size":0,
"encoder_no_repeat_ngram_size":0,
"bad_words_ids":null,
"num_return_sequences":1,
"output_scores":false,
"return_dict_in_generate":false,
"forced_bos_token_id":null,
"forced_eos_token_id":null,
"remove_invalid_values":null
}
}' \
"http://<SERVICE>:80/chat"
```

### Parameters
- `prompt`: The initial text provided by the user, from which the model will continue generating text.
- `return_full_text`: If False, only the newly generated text is returned; if True, the prompt is included along with the generated text.
- `clean_up_tokenization_spaces`: True/False, determines whether to remove potential extra spaces in the text output.
- `prefix`: Prefix added to the prompt.
- `handle_long_generation`: Provides strategies to address generations beyond the model's maximum length capacity.
- `max_length`: The maximum total number of tokens in the generated text.
- `min_length`: The minimum total number of tokens that should be generated.
- `do_sample`: If True, sampling methods will be used for text generation, which can introduce randomness and variation.
- `early_stopping`: If True, the generation will stop early if certain conditions are met, for example, when a satisfactory number of candidates have been found in beam search.
- `num_beams`: The number of beams to be used in beam search. More beams can lead to better results but are more computationally expensive.
- `num_beam_groups`: Divides the number of beams into groups to promote diversity in the generated results.
- `diversity_penalty`: Penalizes the score of tokens that make the current generation too similar to other groups, encouraging diverse outputs.
- `temperature`: Controls the randomness of the output by scaling the logits before sampling.
- `top_k`: Restricts sampling to the k most likely next tokens.
- `top_p`: Uses nucleus sampling to restrict the sampling pool to tokens comprising the top p probability mass.
- `typical_p`: Adjusts the probability distribution to favor tokens that are "typically" likely, given the context.
- `repetition_penalty`: Penalizes tokens that have been generated previously, aiming to reduce repetition.
- `length_penalty`: Modifies scores based on sequence length to encourage shorter or longer outputs.
- `no_repeat_ngram_size`: Prevents the generation of any n-gram more than once.
- `encoder_no_repeat_ngram_size`: Similar to `no_repeat_ngram_size` but applies to the encoder part of encoder-decoder models.
- `bad_words_ids`: A list of token ids that should not be generated.
- `num_return_sequences`: The number of different sequences to generate.
- `output_scores`: Whether to output the prediction scores.
- `return_dict_in_generate`: If True, the method will return a dictionary containing additional information.
- `pad_token_id`: The token ID used for padding sequences to the same length.
- `eos_token_id`: The token ID that signifies the end of a sequence.
- `forced_bos_token_id`: The token ID that is forcibly used as the beginning of a sequence token.
- `forced_eos_token_id`: The token ID that is forcibly used as the end of a sequence when max_length is reached.
- `remove_invalid_values`: If True, filters out invalid values like NaNs or infs from model outputs to prevent crashes.
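
Most of the `generate_kwargs` above correspond directly to Hugging Face `transformers` generation arguments; as a local sanity check, a minimal sketch (the parameter values are examples, not recommendations) would be:

```
from transformers import pipeline

# Example only: load the model locally and pass a few of the parameters
# documented above straight to the text-generation pipeline.
generator = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct")

output = generator(
    "<|user|>\nHow to explain Internet for a medieval knight?<|end|>\n<|assistant|>",
    max_length=200,
    do_sample=True,
    temperature=1.0,
    top_k=10,
    top_p=1.0,
    num_return_sequences=1,
    return_full_text=False,
)
print(output[0]["generated_text"])
```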