From d4fa9060205f200efdb4e46592adfdb62a0b4401 Mon Sep 17 00:00:00 2001 From: Dan Nelson Date: Tue, 27 Aug 2024 19:15:13 +0000 Subject: [PATCH] working integration tests --- integration-tests/test-model.py | 142 +++++++++++++++++++++++++------- 1 file changed, 113 insertions(+), 29 deletions(-) diff --git a/integration-tests/test-model.py b/integration-tests/test-model.py index f6fd320..7fdb4bb 100644 --- a/integration-tests/test-model.py +++ b/integration-tests/test-model.py @@ -19,17 +19,21 @@ ENV = os.getenv('TEST_ENV', 'local') LOCAL_ENDPOINT = "http://localhost:5000/predictions" MODEL = os.getenv('MODEL', 'no model configured') +IS_DEV = "dev" in MODEL def local_run(model_endpoint: str, model_input: dict): + # TODO: figure this out for multi-image local predictions + st = time.time() response = requests.post(model_endpoint, json={"input": model_input}) + et = time.time() - st data = response.json() try: datauri = data["output"] base64_encoded_data = datauri.split(",")[1] data = base64.b64decode(base64_encoded_data) - return Image.open(BytesIO(data)) + return et, Image.open(BytesIO(data)) except Exception as e: print("Error!") print("input:", model_input) @@ -37,17 +41,24 @@ def local_run(model_endpoint: str, model_input: dict): raise e -def replicate_run(model: str, version: str, model_input: dict): - output = replicate.run( - f"{model}:{version}", - input=model_input) - url = output +def replicate_run(version: str, model_input: dict): + pred = replicate.predictions.create( + version=version, + input=model_input + ) - response = requests.get(url) - return Image.open(BytesIO(response.content)) + pred.wait() + + predict_time = pred.metrics['predict_time'] + images = [] + for url in pred.output: + response = requests.get(url) + images.append(Image.open(BytesIO(response.content))) + print(pred.id) + return predict_time, images -def wait_for_server_to_be_ready(url, timeout=300): +def wait_for_server_to_be_ready(url, timeout=400): """ Waits for the server to be 
ready. @@ -84,7 +95,7 @@ def inference_func(): elif ENV in {'staging', 'prod'}: model = replicate.models.get(MODEL) version = model.versions.list()[0] - return partial(replicate_run, MODEL, version.id) + return partial(replicate_run, version) else: raise Exception(f"env should be local, staging, or prod but was {ENV}") @@ -121,34 +132,107 @@ def service(): yield -# things I would test: -# standard generation for dev and schnell. assert that the output image has a dog in it with blip-2 or llava +def get_time_bound(): + """entirely here to make sure we don't recompile""" + return 20 if IS_DEV else 10 -# num_outputs = 4, assert time is about what you'd expect off of the prediction object +def test_base_generation(inference_func): + """standard generation for dev and schnell. assert that the output image has a dog in it with blip-2 or llava""" + test_example = { + 'prompt': "A cool dog", + "aspect_ratio": "1:1", + "num_outputs": 1, + } + time, img_out = inference_func(test_example) + img_out = img_out[0] + + assert time < get_time_bound() + assert img_out.size == (1024, 1024) -# determinism - test with the same seed twice +def test_num_outputs(inference_func): + """num_outputs = 4, assert time is about what you'd expect off of the prediction object""" + base_time = None + for n_outputs in range(1, 5): + test_example = { + 'prompt': "A cool dog", + "aspect_ratio": "1:1", + "num_outputs": n_outputs, + } + time, img_out = inference_func(test_example) + assert len(img_out) == n_outputs + if base_time: + assert time < base_time * n_outputs * 1.5 + if n_outputs == 1: + base_time = time -# changing resolutions - iterate through all resolutions and make sure that the output is valid -def test_base_example(inference_func): +def test_determinism(inference_func): + """determinism - test with the same seed twice""" test_example = { + 'prompt': "A cool dog", + "aspect_ratio": "9:16", + "num_outputs": 1, + "seed": 112358 + } + time, out_one = inference_func(test_example) + out_one = 
out_one[0] + assert time < get_time_bound() + time_two, out_two = inference_func(test_example) + out_two = out_two[0] + assert time_two < get_time_bound() + assert out_one.size == (768, 1344) + + one_array = np.array(out_one, dtype=np.uint16) + two_array = np.array(out_two, dtype=np.uint16) + assert np.allclose(one_array, two_array, atol=20) + + +def test_resolutions(inference_func): + """changing resolutions - iterate through all resolutions and make sure that the output is valid""" + aspect_ratios = { + "1:1": (1024, 1024), + "16:9": (1344, 768), + "21:9": (1536, 640), + "3:2": (1216, 832), + "2:3": (832, 1216), + "4:5": (896, 1088), + "5:4": (1088, 896), + "9:16": (768, 1344), + "9:21": (640, 1536), + } + + for ratio, output in aspect_ratios.items(): + test_example = { + 'prompt': "A cool dog", + "aspect_ratio": ratio, + "num_outputs": 1, + "seed": 112358 + } + + time, img_out = inference_func(test_example) + img_out = img_out[0] + assert img_out.size == output + assert time < get_time_bound() + + +def test_img2img(inference_func): + """img2img. 
does it work?""" + if not IS_DEV: + assert True + return + + test_example= { + 'prompt': 'a cool walrus', 'image': 'https://replicate.delivery/pbxt/IS6z50uYJFdFeh1vCmXe9zasYbG16HqOOMETljyUJ1hmlUXU/keanu.jpeg', - 'scale': 4, - 'face_enhance': True } - expected_url = 'https://replicate.delivery/pbxt/lv0iOW3u6DrNOd30ybfmufqWebiuW10YjILw05YZGbeipZZCB/output.png' - resp = requests.get(expected_url) - expected_img = Image.open(BytesIO(resp.content)) - img_out = inference_func(test_example) - assert img_out.size == expected_img.size - img_out.save('image_out.png') - expected_img.save('expected_out.png') - - img_out_array = np.array(img_out, dtype=np.uint16) - expected_array = np.array(expected_img, dtype=np.uint16) - assert np.allclose(img_out_array, expected_array, atol=20) \ No newline at end of file + _, img_out = inference_func(test_example) + img_out = img_out[0] + assert img_out.size[0] % 16 == 0 + assert img_out.size[0] < 1440 + assert img_out.size[1] % 16 == 0 + assert img_out.size[1] < 1440