From 273ef28464a32810cf3995dd62b1fa839f050f3d Mon Sep 17 00:00:00 2001 From: Dinesh Sajwan Date: Wed, 21 Feb 2024 11:58:07 -0500 Subject: [PATCH] feat(imagegeneration): updated image schema and inference params --- .../src/image_generator.py | 92 ++++++++++++++++--- .../schema.graphql | 1 + 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/lambda/aws-contentgen-appsync-lambda/src/image_generator.py b/lambda/aws-contentgen-appsync-lambda/src/image_generator.py index 90993596..01ad300d 100644 --- a/lambda/aws-contentgen-appsync-lambda/src/image_generator.py +++ b/lambda/aws-contentgen-appsync-lambda/src/image_generator.py @@ -108,24 +108,48 @@ def image_moderation(self): def generate_image(self,input_params): - """Generate image using Using bedrock with configured modelid""" + """Generate image using Using bedrock with configured modelid and params""" input_text=self.input_text - model_id=input_params['model_config']['modelId'] - cfg_scale=input_params['model_config']['model_kwargs']['cfg_scale'] - seed=input_params['model_config']['model_kwargs']['seed'] - steps=input_params['model_config']['model_kwargs']['steps'] - - promptTemplate="{\"text_prompts\":[{\"text\":\"$input_text\\n\"}],\"cfg_scale\":$cfg_scale,\"seed\":$seed,\"steps\":$steps}" - - prompt=promptTemplate.replace("$input_text", input_text).replace("$cfg_scale", str(cfg_scale)).replace("$seed", str(seed)).replace("$steps", str(steps)) + # add default negative prompts + if 'negative_prompts' in input_params: + sample_string_bytes = base64.b64decode(input_params['negative_prompts']) + decoded_negative_prompts = sample_string_bytes.decode("utf-8") + logger.info(f"decoded negative prompts are :: {decoded_negative_prompts}") + negative_prompts= decoded_negative_prompts + else: + negative_prompts= ["poorly rendered","poor background details"] + + model_id=input_params['model_config']['modelId'] + + model_kwargs=input_params['model_config']['model_kwargs'] + params= get_inference_parameters(model_kwargs) + + logger.info(f'SD params :: {params}') + + + request = json.dumps({ + "text_prompts": ( + [{"text": input_text, "weight": 1.0}] + + [{"text": negprompt, "weight": -1.0} for negprompt in negative_prompts] + ), + "cfg_scale":params['cfg_scale'], + "seed": params['seed'], + "steps": params['steps'], + "style_preset": params['style_preset'], + "clip_guidance_preset": params['clip_guidance_preset'], + "sampler": params['sampler'], + "width": params['width'], + "height": params['height'] + }) + try: return self.bedrock_client.invoke_model( modelId= model_id, contentType= "application/json", accept= "application/json", - body=prompt + body=request ) except Exception as e: logger.error(f"Error occured during generating image:: {e}") @@ -172,4 +196,50 @@ def send_job_status(self,variables): auth=aws_auth_appsync, timeout=10 ) - logger.info('res :: {}',responseJobstatus) \ No newline at end of file + logger.info('res :: {}',responseJobstatus) + + +def get_inference_parameters(model_kwargs): + """ Read inference parameters and set default values""" + if 'seed' in model_kwargs: + seed= model_kwargs['seed'] + else: + seed=452345 + if 'cfg_scale' in model_kwargs: + cfg_scale= model_kwargs['cfg_scale'] + else: + cfg_scale=10 + if 'steps' in model_kwargs: + steps= model_kwargs['steps'] + else: + steps=10 + if 'style_preset' in model_kwargs: + style_preset= model_kwargs['style_preset'] + else: + style_preset='photographic' + if 'clip_guidance_preset' in model_kwargs: + clip_guidance_preset= model_kwargs['clip_guidance_preset'] + else: + clip_guidance_preset='FAST_GREEN' + if 'width' in model_kwargs: + width= model_kwargs['width'] + else: + width=512 + if 'height' in model_kwargs: + height= model_kwargs['height'] + else: + height=512 + if 'sampler' in model_kwargs: + sampler= model_kwargs['sampler'] + else: + sampler='K_DPMPP_2S_ANCESTRAL' + return { + "cfg_scale": cfg_scale, + "seed": seed, + "steps": steps, + "style_preset": style_preset, + "clip_guidance_preset": clip_guidance_preset, + "sampler": sampler, + "width": width, + "height": height, + } diff --git a/resources/gen-ai/aws-contentgen-appsync-lambda/schema.graphql b/resources/gen-ai/aws-contentgen-appsync-lambda/schema.graphql index 4d530c93..4979943c 100644 --- a/resources/gen-ai/aws-contentgen-appsync-lambda/schema.graphql +++ b/resources/gen-ai/aws-contentgen-appsync-lambda/schema.graphql @@ -11,6 +11,7 @@ input ImageInput { filename: String model_config: ModelConfiguration input_text: String! + negative_prompts: String } type ImageOutput @aws_iam @aws_cognito_user_pools {