Skip to content

Commit

Permalink
feat(stepfunctions-tasks): allow BedrockInvokeModel to use JsonPath (#…
Browse files Browse the repository at this point in the history
…30298)

### Issue # (if applicable)

Closes #29229.

### Reason for this change

When trying to use JsonPath to specify the S3 URIs that BedrockInvokeModel will read from and write from, you get an error.

Example of the Error message:

`jsii.errors.JavaScriptError: Error: Field references must be the entire string, cannot concatenate them (found 's3://${Token[prompt_bucket.348]}/${Token[prompt_key.349]}')`

### Description of changes

Extended the inputPath property to be allowed as an input value for the task state.
Instead of adding a new S3Uri props in current `BedrockInvokeModelProps` as proposed in the original issue, leveraged the `inputPath` property that is already defined in   `sfn.TaskStateBaseProps` and being extended by `BedrockInvokeModelInputProps` and `BedrockInvokeModelOutputProps`

**Limitation:**  We cannot limit the resource policy to specific input token for which the value might be coming from the prompt, so had to keep it as [*] here. 
### Description of how you validated changes

Added unit tests.
Successful deployment of integration tests in the account.

### Checklist
- [x] Unit Tests
- [x] Integration Tests
- [x] Updated ReadMe
- [x] My code adheres to the [CONTRIBUTING GUIDE](https://github.com/aws/aws-cdk/blob/main/CONTRIBUTING.md) and [DESIGN GUIDELINES](https://github.com/aws/aws-cdk/blob/main/docs/DESIGN_GUIDELINES.md)

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
  • Loading branch information
shikha372 authored Jul 24, 2024
1 parent 82b163d commit f5dd73b
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 19 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@
]
]
}
},
{
"Action": [
"s3:GetObject",
"s3:PutObject"
],
"Effect": "Allow",
"Resource": {
"Fn::Join": [
"",
[
"arn:",
{
"Ref": "AWS::Partition"
},
":s3:::*"
]
]
}
}
],
"Version": "2012-10-17"
Expand Down Expand Up @@ -72,7 +91,19 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"End\":true,\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"Next\":\"Prompt3\",\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
":states:::bedrock:invokeModel\",\"Parameters\":{\"ModelId\":\"arn:",
{
"Ref": "AWS::Partition"
},
":bedrock:",
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt3\":{\"End\":true,\"Type\":\"Task\",\"InputPath\":\"$.names\",\"OutputPath\":\"$.names\",\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
Expand All @@ -84,7 +115,7 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}}},\"TimeoutSeconds\":30}"
"::foundation-model/amazon.titan-text-express-v1\",\"Input\":{\"S3Uri.$\":\"$.names\"},\"Output\":{\"S3Uri.$\":\"$.names\"}}}},\"TimeoutSeconds\":30}"
]
]
},
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ const prompt2 = new BedrockInvokeModel(stack, 'Prompt2', {
resultPath: '$',
});

const chain = sfn.Chain.start(prompt1).next(prompt2);
const prompt3 = new BedrockInvokeModel(stack, 'Prompt3', {
model,
inputPath: sfn.JsonPath.stringAt('$.names'),
outputPath: sfn.JsonPath.stringAt('$.names'),
});

const chain = sfn.Chain.start(prompt1).next(prompt2).next(prompt3);

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(chain),
Expand Down
21 changes: 21 additions & 0 deletions packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,27 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
names: sfn.JsonPath.stringAt('$.Body.results[0].outputText'),
},
});
```
### Using Input Path

Provide S3 URI as an input or output path to invoke a model

```ts

import * as bedrock from 'aws-cdk-lib/aws-bedrock';

const model = bedrock.FoundationModel.fromFoundationModelId(
this,
'Model',
bedrock.FoundationModelIdentifier.AMAZON_TITAN_TEXT_G1_EXPRESS_V1,
);

const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
model,
inputPath: sfn.JsonPath.stringAt('$.prompt'),
outputPath: sfn.JsonPath.stringAt('$.prompt'),
});

```

You can apply a guardrail to the invocation by setting `guardrail`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {

constructor(scope: Construct, id: string, private readonly props: BedrockInvokeModelProps) {
super(scope, id, props);

this.integrationPattern = props.integrationPattern ?? sfn.IntegrationPattern.REQUEST_RESPONSE;

validatePatternSupported(this.integrationPattern, BedrockInvokeModel.SUPPORTED_INTEGRATION_PATTERNS);

const isBodySpecified = props.body !== undefined;
const isInputSpecified = props.input !== undefined && props.input.s3Location !== undefined;
//Either specific props.input with bucket name and object key or input s3 path
const isInputSpecified = (props.input !== undefined && props.input.s3Location !== undefined) || (props.inputPath !== undefined);

if (isBodySpecified && isInputSpecified) {
throw new Error('Either `body` or `input` must be specified, but not both.');
Expand All @@ -171,7 +173,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
}),
];

if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
if (this.props.inputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
Expand All @@ -188,7 +204,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
);
}

if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
if (this.props.outputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
Expand Down Expand Up @@ -241,10 +271,10 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
Body: this.props.body?.value,
Input: this.props.input?.s3Location ? {
S3Uri: `s3://${this.props.input.s3Location.bucketName}/${this.props.input.s3Location.objectKey}`,
} : undefined,
} : this.props.inputPath ? { S3Uri: this.props.inputPath } : undefined,
Output: this.props.output?.s3Location ? {
S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`,
} : undefined,
} : this.props.outputPath ? { S3Uri: this.props.outputPath }: undefined,
GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier,
GuardrailVersion: this.props.guardrail?.guardrailVersion,
Trace: this.props.traceEnabled === undefined
Expand All @@ -254,5 +284,6 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
: 'DISABLED',
}),
};
}
};
}

Loading

0 comments on commit f5dd73b

Please sign in to comment.