Skip to content

Commit

Permalink
feat(visualqa): added review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Dinesh Sajwan committed Mar 11, 2024
1 parent 34a5ba1 commit 7d4d3a1
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,14 @@ def process_visual_qa(input_params,status_variables,filename):
local_file_path= download_file(bucket_name,filename)
base64_images=encode_image_to_base64(local_file_path,filename)
status_variables['answer']= generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question)
status_variables['jobstatus'] = JobStatus.DONE.status
streaming = input_params.get("streaming", False)
if(status_variables['answer'] is None):
status_variables['answer'] = JobStatus.ERROR_PREDICTION.status
error = JobStatus.ERROR_PREDICTION.get_message()
status_variables['answer'] = error.decode("utf-8")
status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
else:
status_variables['jobstatus'] = JobStatus.DONE.status
streaming = input_params.get("streaming", False)

else:
logger.error('Invalid Model , cannot load LLM , returning..')
Expand Down Expand Up @@ -253,7 +259,7 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu

return status_variables

def generate_vision_answer_bedrock(bedrock_client,base64_images, model_id,decoded_question):
def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded_question):
system_prompt=""
# use system prompt for fine tuning the performamce
# system_prompt= """
Expand Down Expand Up @@ -293,10 +299,15 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images, model_id,decode
}

body=json.dumps({'messages': [messages],**claude_config, "system": system_prompt})
response = bedrock_client.invoke_model(
body=body, modelId=model_id, accept="application/json",
contentType="application/json"
)
try:
response = bedrock_client.invoke_model(
body=body, modelId=model_id, accept="application/json",
contentType="application/json"
)
except Exception as err:
logger.exception(f'Error occurred , Reason :{err}')
return None

response = json.loads(response['body'].read().decode('utf-8'))

formated_response= response['content'][0]['text']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
logger = Logger(service="QUESTION_ANSWERING")


sageMakerEndpoint= os.environ['SAGEMAKER_ENDPOINT']

class ContentHandler(LLMContentHandler):
content_type = "application/json"
Expand Down Expand Up @@ -48,7 +47,6 @@ class MultiModal():

@classmethod
def sagemakerendpoint_llm(self,model_id):
if(sageMakerEndpoint ==model_id):
try:
endpoint= SagemakerEndpoint(
endpoint_name=model_id,
Expand All @@ -60,9 +58,7 @@ def sagemakerendpoint_llm(self,model_id):
except Exception as err:
logger.error(f' Error when accessing sagemaker endpoint :: {model_id} , returning...')
return None
else:
logger.error(f" The sagemaker model Id {model_id} do not match a sagemaker endpoint name {sageMakerEndpoint}")
return None




Expand Down
18 changes: 2 additions & 16 deletions src/patterns/gen-ai/aws-qa-appsync-opensearch/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,7 @@ export interface QaAppsyncOpensearchProps {
*/
readonly customDockerLambdaProps?: DockerLambdaCustomProps | undefined;

/**
* Optional. Allows to provide custom lambda code
* and settings instead of the existing
*/
readonly sagemakerEndpointName?: string;

}

/**
Expand Down Expand Up @@ -466,15 +462,7 @@ export class QaAppsyncOpensearch extends BaseClass {
}),
);

if (props.sagemakerEndpointName) {
question_answering_function_role.addToPolicy(
new iam.PolicyStatement({
effect: iam.Effect.ALLOW,
actions: ['sagemaker:InvokeEndpoint'],
resources: ['arn:'+ Aws.PARTITION +':sagemaker:' + Aws.ACCOUNT_ID + ':endpoint/*'],
}),
);
}

// The lambda will access the opensearch credentials
if (props.openSearchSecret) {
props.openSearchSecret.grantRead(question_answering_function_role);
Expand Down Expand Up @@ -555,7 +543,6 @@ export class QaAppsyncOpensearch extends BaseClass {
true,
);

const sagemakerEndpointNamestr = props.sagemakerEndpointName || '';
const construct_docker_lambda_props = {
code: lambda.DockerImageCode.fromImageAsset(
path.join(
Expand All @@ -579,7 +566,6 @@ export class QaAppsyncOpensearch extends BaseClass {
OPENSEARCH_DOMAIN_ENDPOINT: opensearch_helper.getOpenSearchEndpoint(props),
OPENSEARCH_INDEX: props.openSearchIndexName,
OPENSEARCH_SECRET_ID: SecretId,
SAGEMAKER_ENDPOINT: sagemakerEndpointNamestr,
},
...(props.lambdaProvisionedConcurrency && {
currentVersionOptions: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ describe('QA Appsync Open search construct', () => {
openSearchIndexName: 'demoindex',
openSearchSecret: osSecret,
cognitoUserPool: userPoolLoaded,
sagemakerEndpointName: 'sageMakerEndpoint',
};

qaTestConstruct = new QaAppsyncOpensearch(qaTestStack, 'test', qaTestProps);
Expand Down Expand Up @@ -90,7 +89,6 @@ describe('QA Appsync Open search construct', () => {
},
OPENSEARCH_INDEX: 'demoindex',
OPENSEARCH_SECRET_ID: 'OSSecretId',
SAGEMAKER_ENDPOINT: 'sageMakerEndpoint',
},
},
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ describe('QA Appsync Open search construct', () => {
openSearchIndexName: 'demoindex',
openSearchSecret: osSecret,
cognitoUserPool: userPoolLoaded,
sagemakerEndpointName: 'sageMakerEndpoint',
};

qaTestConstruct = new QaAppsyncOpensearch(qaTestStack, 'test', qaTestProps);
Expand Down Expand Up @@ -88,7 +87,6 @@ describe('QA Appsync Open search construct', () => {
OPENSEARCH_DOMAIN_ENDPOINT: 'osendppint.amazon.aws.com',
OPENSEARCH_INDEX: 'demoindex',
OPENSEARCH_SECRET_ID: 'OSSecretId',
SAGEMAKER_ENDPOINT: 'sageMakerEndpoint',
},
},
});
Expand Down Expand Up @@ -225,7 +223,6 @@ describe('QA Appsync Open search construct custom lambda', () => {
openSearchSecret: osSecret,
cognitoUserPool: userPoolLoaded,
customDockerLambdaProps: customDockerLambdaProps,
sagemakerEndpointName: 'sageMakerEndpoint',
};

qaTestConstruct = new QaAppsyncOpensearch(qaTestStack, 'test', qaTestProps);
Expand Down Expand Up @@ -255,7 +252,6 @@ describe('QA Appsync Open search construct custom lambda', () => {
OPENSEARCH_DOMAIN_ENDPOINT: 'osendppint.amazon.aws.com',
OPENSEARCH_INDEX: 'demoindex',
OPENSEARCH_SECRET_ID: 'OSSecretId',
SAGEMAKER_ENDPOINT: 'sageMakerEndpoint',
TEST_VAR: 'hello',
},
},
Expand Down

0 comments on commit 7d4d3a1

Please sign in to comment.