Commit 7681476 (parent 60bdc7c): 4 changed files with 196 additions and 3 deletions.
@@ -0,0 +1,6 @@
```python
import enum

class LlmHTTPEndpoints(enum.Enum):
  GENERATE = 'predict'
  SALIENCE = 'salience'
  TOKENIZE = 'tokenize'
```
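The enum's values are the model_server route names, so a LIT-facing endpoint resolves to a URL path in one step. A minimal sketch of the lookup that `LlmOverHTTP` (below) relies on; the host and port are placeholders:

```python
from lit_nlp.examples.gcp import constants as lit_gcp_constants

# Enums are constructed by value, so the raw string 'predict'
# resolves to the GENERATE member.
endpoint = lit_gcp_constants.LlmHTTPEndpoints('predict')
assert endpoint is lit_gcp_constants.LlmHTTPEndpoints.GENERATE

# Build a request URL the same way LlmOverHTTP does (placeholder host).
url = f'http://localhost:5432/{endpoint.value}'  # -> '.../predict'
```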
@@ -0,0 +1,106 @@
```python
"""Wrapper for connecting to LLMs on GCP via the model_server HTTP API."""

from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.gcp import constants as lit_gcp_constants
from lit_nlp.examples.prompt_debugging import constants as pd_constants
from lit_nlp.examples.prompt_debugging import utils as pd_utils
from lit_nlp.lib import serialize
import requests

""" | ||
Plan for this module: | ||
From GitHub: | ||
* Rebase to include cl/672527408 and the CL described above | ||
* Define an enum to track HTTP endpoints across Python modules | ||
* Adopt HTTP endpoint enum across model_server.py and LlmOverHTTP | ||
* Adopt model_specs.py in LlmOverHTTP, using HTTP endpoint enum for | ||
conditional additions | ||
""" | ||
_LlmHTTPEndpoints = lit_gcp_constants.LlmHTTPEndpoints


class LlmOverHTTP(lit_model.BatchedRemoteModel):
  """LIT model that queries an LLM served over HTTP by model_server."""

  def __init__(
      self,
      base_url: str,
      endpoint: str | _LlmHTTPEndpoints,
      max_concurrent_requests: int = 4,
      max_qps: int | float = 25,
  ):
    super().__init__(max_concurrent_requests, max_qps)
    # Accepts an enum member or its value, e.g. 'predict' or 'salience'.
    self.endpoint = _LlmHTTPEndpoints(endpoint)
    self.url = f'{base_url}/{self.endpoint.value}'

  def input_spec(self) -> lit_types.Spec:
    input_spec = pd_constants.INPUT_SPEC

    # The salience endpoint requires additional input fields.
    if self.endpoint == _LlmHTTPEndpoints.SALIENCE:
      input_spec |= pd_constants.INPUT_SPEC_SALIENCE

    return input_spec

  def output_spec(self) -> lit_types.Spec:
    if self.endpoint == _LlmHTTPEndpoints.GENERATE:
      return (
          pd_constants.OUTPUT_SPEC_GENERATION
          | pd_constants.OUTPUT_SPEC_GENERATION_EMBEDDINGS
      )
    elif self.endpoint == _LlmHTTPEndpoints.SALIENCE:
      return pd_constants.OUTPUT_SPEC_SALIENCE
    else:
      return pd_constants.OUTPUT_SPEC_TOKENIZER

  def predict_minibatch(
      self, inputs: list[lit_types.JsonDict]
  ) -> list[lit_types.JsonDict]:
    """Runs prediction on a batch of inputs by POSTing to the endpoint URL.

    Args:
      inputs: sequence of inputs, following model.input_spec()

    Returns:
      list of outputs, following model.output_spec()
    """
    response = requests.post(
        self.url, data=serialize.to_json(list(inputs), simple=True)
    )

    if not (200 <= response.status_code < 300):
      raise RuntimeError(
          f'Request to {self.url} failed with status {response.status_code}.'
      )

    outputs = serialize.from_json(response.text)
    return outputs
```
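Since `predict_minibatch` amounts to a plain JSON-over-HTTP round trip, the wire contract can be sketched without LIT at all. A hypothetical raw request, assuming a model_server at a placeholder URL, the 'predict' route from the enum above, and a 'prompt' input field (the real field names come from pd_constants.INPUT_SPEC, and serialize.to_json(..., simple=True) is approximated here with json.dumps):

```python
import json
import requests

# Placeholder host; 'predict' is the value of LlmHTTPEndpoints.GENERATE.
url = 'http://localhost:5432/predict'
# One JsonDict per example; real field names follow input_spec().
inputs = [{'prompt': 'Tell me about the Learning Interpretability Tool.'}]

response = requests.post(url, data=json.dumps(inputs))
response.raise_for_status()  # rough equivalent of the status check above
outputs = json.loads(response.text)  # one dict per input, per output_spec()
```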
```python
def initialize_model_group_for_salience(
    name: str, base_url: str, *args, **kw
) -> dict[str, lit_model.Model]:
  """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
  salience_name, tokenizer_name = pd_utils.generate_model_group_names(name)

  generation_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.GENERATE, **kw
  )
  salience_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.SALIENCE, **kw
  )
  tokenizer_model = LlmOverHTTP(
      *args, base_url=base_url, endpoint=_LlmHTTPEndpoints.TOKENIZE, **kw
  )

  return {
      name: generation_model,
      salience_name: salience_model,
      tokenizer_name: tokenizer_model,
  }
```
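For illustration, a hypothetical wiring of one model group; the model name and URL are placeholders, and the derived keys follow the docstring's '_{name}_salience' / '_{name}_tokenizer' pattern:

```python
# Placeholder name and URL for a running model_server instance.
models = initialize_model_group_for_salience('gemma', 'http://localhost:5432')

# Expected keys, per the docstring and pd_utils.generate_model_group_names:
# 'gemma' (generation), '_gemma_salience', '_gemma_tokenizer'
print(sorted(models.keys()))
```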
@@ -0,0 +1,78 @@
```python
"""Server for sequence salience with a left-to-right language model."""

from collections.abc import Mapping, Sequence
import sys
from typing import Optional

from absl import app
from absl import flags
from absl import logging
from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.examples.gcp import model as lit_gcp_model
from lit_nlp.examples.prompt_debugging import datasets as pd_datasets
from lit_nlp.examples.prompt_debugging import layouts as pd_layouts


_FLAGS = flags.FLAGS

_SPLASH_SCREEN_DOC = """
# Language Model Salience

To begin, select an example, then click the segment(s) (tokens, words, etc.)
of the output that you would like to explain. Preceding segment(s) will be
highlighted according to their importance to the selected target segment(s),
with darker colors indicating a greater influence (salience) of that segment
on the model's likelihood of the target segment.
"""

def init_llm_on_gcp(
    name: str, base_url: str, *args, **kw
) -> Mapping[str, lit_model.Model]:
  return lit_gcp_model.initialize_model_group_for_salience(
      name, base_url, *args, **kw
  )


def get_wsgi_app() -> Optional[dev_server.LitServerType]:
  """Return WSGI app for container-hosted demos."""
  _FLAGS.set_default("server_type", "external")
  _FLAGS.set_default("demo_mode", True)
  _FLAGS.set_default("page_title", "LM Prompt Debugging")
  _FLAGS.set_default("default_layout", pd_layouts.THREE_PANEL)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  if unused:
    logging.info("lm_demo:get_wsgi_app() called with unused args: %s", unused)
  return main([])
```
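`get_wsgi_app()` exists for process managers such as gunicorn, which is why it parses known flags itself rather than going through `app.run`. As a rough local-testing sketch, assuming `serve()` returns a WSGI callable when `server_type` is "external" (host and port are hypothetical):

```python
from wsgiref import simple_server

app = get_wsgi_app()  # sets flag defaults, then builds the LIT server
if app is not None:
  # Serve on a hypothetical local port without gunicorn.
  httpd = simple_server.make_server('localhost', 5432, app)
  httpd.serve_forever()
```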
```python
def main(argv: Sequence[str]) -> Optional[dev_server.LitServerType]:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  lit_demo = dev_server.Server(
      models={},
      datasets={},
      layouts=pd_layouts.PROMPT_DEBUGGING_LAYOUTS,
      model_loaders={
          'LLM on GCP': (init_llm_on_gcp, {
              'name': lit_types.String(),
              'base_url': lit_types.String(),
              'max_concurrent_requests': lit_types.Integer(default=1),
              'max_qps': lit_types.Scalar(default=25),
          }),
      },
      dataset_loaders=pd_datasets.get_dataset_loaders(),
      onboard_start_doc=_SPLASH_SCREEN_DOC,
      **server_flags.get_flags(),
  )
  return lit_demo.serve()


if __name__ == "__main__":
  app.run(main)
```
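The 'LLM on GCP' entry in `model_loaders` exposes `init_llm_on_gcp`'s parameters in the LIT UI; filling in that form is equivalent to a direct call like this sketch (name and URL are placeholders):

```python
models = init_llm_on_gcp(
    name='gemma',
    base_url='http://localhost:5432',
    max_concurrent_requests=1,
    max_qps=25,
)
# Yields the generation model plus its salience and tokenizer helpers,
# suitable for passing to dev_server.Server(models=models).
```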