From 7f307c9ee009c630a889953c8182e3a3d33fdfdc Mon Sep 17 00:00:00 2001 From: maz Date: Thu, 4 Jul 2024 22:42:47 +0900 Subject: [PATCH] fix: validate arn --- .../lib/bedrock/guardrail.ts | 19 +++++++++++---- .../test/bedrock/invoke-model.test.ts | 24 +++++++++++++++++-- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/guardrail.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/guardrail.ts index 3a422ee690623..bdc31f418929e 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/guardrail.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/lib/bedrock/guardrail.ts @@ -1,4 +1,4 @@ -import { Token } from '../../../core'; +import { Arn, ArnFormat, Token } from '../../../core'; /** * Guradrail settings for BedrockInvokeModel @@ -34,10 +34,21 @@ export class Guardrail { */ private constructor(public readonly guardrailIdentifier: string, public readonly guardrailVersion: string) { if (!Token.isUnresolved(guardrailIdentifier)) { - const guardrailPattern = /^(([a-z0-9]+)|(arn:aws(-[^:]+)?:bedrock:[a-z0-9-]{1,20}:[0-9]{12}:guardrail\/[a-z0-9]+))$/; + let gurdrailId = undefined; - if (!guardrailPattern.test(guardrailIdentifier)) { - throw new Error(`You must set guardrailIdentifier to the id or the arn of Guardrail, got ${guardrailIdentifier}`); + if (guardrailIdentifier.startsWith('arn:')) { + const arn = Arn.split(guardrailIdentifier, ArnFormat.SLASH_RESOURCE_NAME); + if (!arn.resourceName) { + throw new Error(`Invalid ARN format. The ARN of Guradrail should have the format: \`arn::bedrock:::guardrail/\`, got ${guardrailIdentifier}.`); + } + gurdrailId = arn.resourceName; + } else { + gurdrailId = guardrailIdentifier; + } + + const guardrailPattern = /^[a-z0-9]+$/; + if (!guardrailPattern.test(gurdrailId)) { + throw new Error(`The id of Guardrail must contain only lowercase letters and numbers, got ${gurdrailId}.`); } if (guardrailIdentifier.length > 2048) { diff --git a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts index 10c21e187549b..3c480ea1638a3 100644 --- a/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts +++ b/packages/aws-cdk-lib/aws-stepfunctions-tasks/test/bedrock/invoke-model.test.ts @@ -509,7 +509,7 @@ describe('Invoke Model', () => { }); }); - test('guardrail fails when invalid guardrailIdentifier is set', () => { + test('guardrail fails when guardrailIdentifier is set to invalid id', () => { // GIVEN const stack = new cdk.Stack(); const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); @@ -526,7 +526,27 @@ describe('Invoke Model', () => { guardrail: Guardrail.enableDraft('invalid-id'), }); // THEN - }).toThrow('You must set guardrailIdentifier to the id or the arn of Guardrail, got invalid-id'); + }).toThrow('The id of Guardrail must contain only lowercase letters and numbers, got invalid-id'); + }); + + test('guardrail fails when guardrailIdentifier is set to invalid ARN', () => { + // GIVEN + const stack = new cdk.Stack(); + const model = bedrock.ProvisionedModel.fromProvisionedModelArn(stack, 'Imported', 'arn:aws:bedrock:us-turbo-2:123456789012:provisioned-model/abc-123'); + + expect(() => { + // WHEN + new BedrockInvokeModel(stack, 'Invoke', { + model, + body: sfn.TaskInput.fromObject( + { + prompt: 'Hello world', + }, + ), + guardrail: Guardrail.enableDraft('arn:aws:bedrock:us-turbo-2:123456789012:guardrail'), + }); + // THEN + }).toThrow('Invalid ARN format. The ARN of Guradrail should have the format: `arn::bedrock:::guardrail/`, got arn:aws:bedrock:us-turbo-2:123456789012:guardrail.'); }); test('guardrail fails when guardrailIdentifier length is invalid', () => {