Skip to content

Commit

Permalink
Merge branch 'main' into ttakamiy/AAP-13719/completions-var-list-support
Browse files Browse the repository at this point in the history
  • Loading branch information
TamiTakamiya committed Nov 7, 2023
2 parents 831ab37 + 1873d92 commit 2136058
Show file tree
Hide file tree
Showing 15 changed files with 301 additions and 87 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/code_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ jobs:
##############
# Python tests
##############
- name: Setup Python 3.9
uses: actions/setup-python@v4
with:
python-version: '3.9'

- name: Install dependencies
run: |
python3 -m pip install --upgrade pip
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/pip_audit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
# The interpreter version is pinned here rather than taken from a matrix, so
# the step name states it literally. The version is quoted so YAML keeps it a
# string (an unquoted 3.10 would be read as the float 3.1).
- name: Set up Python 3.9
  uses: actions/setup-python@v4
  with:
    python-version: '3.9'
- name: install
run: |
python -m venv env/
Expand All @@ -36,4 +40,5 @@ jobs:
GHSA-jh3w-4vvf-mjgr
GHSA-ww3m-ffrm-qvqv
PYSEC-2023-100
PYSEC-2023-228
GHSA-mq26-g339-26xf
5 changes: 5 additions & 0 deletions ansible_wisdom/ai/api/model_client/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class WcaTokenFailure(WcaException):
"""An attempt to retrieve a WCA Token failed."""


@dataclass
class WcaCloudflareRejection(WcaException):
    """Cloudflare rejected the request.

    Declares no fields of its own; anything it carries comes from
    ``WcaException``.
    """


@dataclass
class WcaInferenceFailure(WcaException):
"""An attempt to run a WCA inference failed."""
Expand Down
59 changes: 36 additions & 23 deletions ansible_wisdom/ai/api/model_client/tests/test_wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,34 @@
WCA_REQUEST_ID_HEADER,
WCAClient,
ibm_cloud_identity_token_hist,
ibm_cloud_identity_token_retry_counter,
wca_codegen_hist,
wca_codegen_retry_counter,
wca_codematch_hist,
wca_codematch_retry_counter,
)
from django.apps import apps
from django.test import override_settings
from prometheus_client import Counter, Histogram
from requests.exceptions import HTTPError, ReadTimeout
from test_utils import WisdomServiceLogAwareTestCase

DEFAULT_SUGGESTION_ID = uuid.uuid4()


class MockResponse:
    """Minimal stand-in for a ``requests`` response object in the WCA client tests.

    Args:
        json: Payload returned verbatim by :meth:`json`.
        status_code: HTTP status code exposed as ``status_code``.
        headers: Optional mapping of response headers; defaults to a fresh
            empty dict so instances never share one.
        text: Optional raw body exposed as the ``text`` attribute.
    """

    def __init__(self, json, status_code, headers=None, text=None):
        self._json = json
        self.status_code = status_code
        self.headers = {} if headers is None else headers
        # ``text`` is a plain attribute, mirroring ``requests.Response.text``
        # (a property, not a callable). A ``text()`` method of the same name
        # would be shadowed by this attribute and fail when called, so none
        # is defined.
        self.text = text

    def json(self):
        """Return the payload given at construction time."""
        return self._json

    def raise_for_status(self):
        """Do nothing: the mock never raises, whatever ``status_code`` is."""
        return

Expand Down Expand Up @@ -76,14 +84,16 @@ def stub_wca_client(
return model_id, model_client, model_input


def assert_call_count_metrics(hist):
def assert_call_count_metrics(metric):
def count_metrics_decorator(func):
@wraps(func)
def wrapped_function(*args, **kwargs):
def get_count():
for metric in hist.collect():
for sample in metric.samples:
if sample.name.endswith("_count"):
for m in metric.collect():
for sample in m.samples:
if isinstance(metric, Histogram) and sample.name.endswith("_count"):
return sample.value
if isinstance(metric, Counter) and sample.name.endswith("_total"):
return sample.value
return 0.0

Expand Down Expand Up @@ -223,7 +233,7 @@ def setUp(self):
def tearDown(self):
self.secret_manager_patcher.stop()

@assert_call_count_metrics(hist=ibm_cloud_identity_token_hist)
@assert_call_count_metrics(metric=ibm_cloud_identity_token_hist)
def test_get_token(self):
headers = {
"Content-Type": "application/x-www-form-urlencoded",
Expand Down Expand Up @@ -252,24 +262,25 @@ def test_get_token(self):
data=data,
)

@assert_call_count_metrics(hist=ibm_cloud_identity_token_hist)
@assert_call_count_metrics(metric=ibm_cloud_identity_token_hist)
@assert_call_count_metrics(metric=ibm_cloud_identity_token_retry_counter)
def test_get_token_http_error(self):
model_client = WCAClient(inference_url='http://example.com/')
model_client.session.post = Mock(side_effect=HTTPError(404))
with self.assertRaises(WcaTokenFailure):
model_client.get_token("api-key")

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer(self):
self._do_inference(
suggestion_id=str(DEFAULT_SUGGESTION_ID), request_id=str(DEFAULT_SUGGESTION_ID)
)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_without_suggestion_id(self):
self._do_inference(suggestion_id=None, request_id=str(DEFAULT_SUGGESTION_ID))

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_without_request_id_header(self):
self._do_inference(suggestion_id=str(DEFAULT_SUGGESTION_ID), request_id=None)

Expand Down Expand Up @@ -331,7 +342,7 @@ def _do_inference(self, suggestion_id=None, request_id=None):
)
self.assertEqual(result, predictions)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_timeout(self):
model_id = "zavala"
model_input = {
Expand Down Expand Up @@ -360,7 +371,8 @@ def test_infer_timeout(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_retry_counter)
def test_infer_http_error(self):
model_id = "zavala"
model_input = {
Expand Down Expand Up @@ -389,7 +401,7 @@ def test_infer_http_error(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_request_id_correlation_failure(self):
model_id = "zavala"
model_input = {
Expand Down Expand Up @@ -425,7 +437,7 @@ def test_infer_request_id_correlation_failure(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_garbage_model_id(self):
stub = stub_wca_client(400, "zavala")
model_id, model_client, model_input = stub
Expand All @@ -435,7 +447,7 @@ def test_infer_garbage_model_id(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_invalid_model_id_for_api_key(self):
stub = stub_wca_client(403, "zavala")
model_id, model_client, model_input = stub
Expand All @@ -445,7 +457,7 @@ def test_infer_invalid_model_id_for_api_key(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_empty_response(self):
stub = stub_wca_client(204, "zavala")
model_id, model_client, model_input = stub
Expand All @@ -455,7 +467,7 @@ def test_infer_empty_response(self):
)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codegen_hist)
@assert_call_count_metrics(metric=wca_codegen_hist)
def test_infer_preprocessed_multitask_prompt_error(self):
# See https://issues.redhat.com/browse/AAP-16642
stub = stub_wca_client(
Expand Down Expand Up @@ -485,7 +497,7 @@ def setUp(self):
def tearDown(self):
self.secret_manager_patcher.stop()

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
def test_codematch(self):
model_id = "sample_model_name"
suggestions = [
Expand Down Expand Up @@ -553,7 +565,7 @@ def test_codematch(self):
)
self.assertEqual(result, client_response)

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
def test_codematch_timeout(self):
model_id = "sample_model_name"
suggestions = [
Expand All @@ -578,7 +590,8 @@ def test_codematch_timeout(self):
model_client.codematch(model_input=model_input, model_id=model_id)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_retry_counter)
def test_codematch_http_error(self):
model_id = "sample_model_name"
model_input = {
Expand All @@ -604,23 +617,23 @@ def test_codematch_http_error(self):
model_client.codematch(model_input=model_input, model_id=model_id)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
def test_codematch_bad_model_id(self):
stub = stub_wca_client(400, "sample_model_name")
model_id, model_client, model_input = stub
with self.assertRaises(WcaInvalidModelId) as e:
model_client.codematch(model_input=model_input, model_id=model_id)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
def test_codematch_invalid_model_id_for_api_key(self):
stub = stub_wca_client(403, "sample_model_name")
model_id, model_client, model_input = stub
with self.assertRaises(WcaInvalidModelId) as e:
model_client.codematch(model_input=model_input, model_id=model_id)
self.assertEqual(e.exception.model_id, model_id)

@assert_call_count_metrics(hist=wca_codematch_hist)
@assert_call_count_metrics(metric=wca_codematch_hist)
def test_codematch_empty_response(self):
stub = stub_wca_client(204, "sample_model_name")
model_id, model_client, model_input = stub
Expand Down
47 changes: 43 additions & 4 deletions ansible_wisdom/ai/api/model_client/wca_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from django.apps import apps
from django.conf import settings
from django_prometheus.conf import NAMESPACE
from prometheus_client import Histogram
from prometheus_client import Counter, Histogram
from requests.exceptions import HTTPError

from ..aws.wca_secret_manager import Suffixes, WcaSecretManagerError
Expand Down Expand Up @@ -48,6 +48,21 @@
"Histogram of IBM Cloud identity token API processing time",
namespace=NAMESPACE,
)
# Retry counters, one per outbound API. Each is incremented by the matching
# ``WCAClient.on_backoff_*`` static method, which is passed as the
# ``on_backoff`` handler of the ``backoff.on_exception`` decorator wrapping
# that API's request.
wca_codegen_retry_counter = Counter(
    'wca_codegen_retries',
    "Counter of WCA codegen API invocation retries",
    namespace=NAMESPACE,
)
wca_codematch_retry_counter = Counter(
    'wca_codematch_retries',
    "Counter of WCA codematch API invocation retries",
    namespace=NAMESPACE,
)
ibm_cloud_identity_token_retry_counter = Counter(
    'ibm_cloud_identity_token_retries',
    "Counter of IBM Cloud identity token API invocation retries",
    namespace=NAMESPACE,
)


class WCAClient(ModelMeshClient):
Expand All @@ -71,6 +86,18 @@ def fatal_exception(exc):
# retry on all other errors (e.g. network)
return False

@staticmethod
def on_backoff_inference(details):
    """Count one retry of the WCA codegen request.

    Passed as ``on_backoff`` to the ``backoff.on_exception`` decorator
    around the codegen ``post_request``.

    Args:
        details: Invocation details supplied by ``backoff``; unused here.
    """
    wca_codegen_retry_counter.inc()

@staticmethod
def on_backoff_codematch(details):
    """Count one retry of the WCA codematch request.

    Passed as ``on_backoff`` to the ``backoff.on_exception`` decorator
    around the codematch ``post_request``.

    Args:
        details: Invocation details supplied by ``backoff``; unused here.
    """
    wca_codematch_retry_counter.inc()

@staticmethod
def on_backoff_ibm_cloud_identity_token(details):
    """Count one retry of the IBM Cloud identity token request.

    Passed as ``on_backoff`` to the ``backoff.on_exception`` decorator
    around the token ``post_request`` in ``get_token``.

    Args:
        details: Invocation details supplied by ``backoff``; unused here.
    """
    ibm_cloud_identity_token_retry_counter.inc()

def infer(self, model_input, model_id=None, suggestion_id=None):
logger.debug(f"Input prompt: {model_input}")

Expand Down Expand Up @@ -115,7 +142,11 @@ def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_i
prediction_url = f"{self._inference_url}/v1/wca/codegen/ansible"

@backoff.on_exception(
backoff.expo, Exception, max_tries=self.retries + 1, giveup=self.fatal_exception
backoff.expo,
Exception,
max_tries=self.retries + 1,
giveup=self.fatal_exception,
on_backoff=self.on_backoff_inference,
)
@wca_codegen_hist.time()
def post_request():
Expand Down Expand Up @@ -157,7 +188,11 @@ def get_token(self, api_key):
data = {"grant_type": "urn:ibm:params:oauth:grant-type:apikey", "apikey": api_key}

@backoff.on_exception(
backoff.expo, Exception, max_tries=self.retries + 1, giveup=self.fatal_exception
backoff.expo,
Exception,
max_tries=self.retries + 1,
giveup=self.fatal_exception,
on_backoff=self.on_backoff_ibm_cloud_identity_token,
)
@ibm_cloud_identity_token_hist.time()
def post_request():
Expand Down Expand Up @@ -253,7 +288,11 @@ def codematch(self, model_input, model_id=None):
suggestion_count = len(suggestions)

@backoff.on_exception(
backoff.expo, Exception, max_tries=self.retries + 1, giveup=self.fatal_exception
backoff.expo,
Exception,
max_tries=self.retries + 1,
giveup=self.fatal_exception,
on_backoff=self.on_backoff_codematch,
)
@wca_codematch_hist.time()
def post_request():
Expand Down
Loading

0 comments on commit 2136058

Please sign in to comment.