From f69dac38f1d4662f1e7eec58a6b7015e7bb10d21 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Wed, 4 Dec 2024 15:50:31 -0500 Subject: [PATCH] feat(js/flows): consolidated `defineFlow` and `defineStreamingFlow` (#1401) --- js/core/src/flow.ts | 67 +++++++++++++--------- js/core/tests/flow_test.ts | 107 ++++++++++++++++++++++++++++++++++- js/core/tests/utils.ts | 60 ++++++++++++++++++++ js/genkit/src/genkit.ts | 9 ++- js/genkit/tests/flow_test.ts | 81 ++++++++++++++++++++++++++ 5 files changed, 293 insertions(+), 31 deletions(-) create mode 100644 js/core/tests/utils.ts create mode 100644 js/genkit/tests/flow_test.ts diff --git a/js/core/src/flow.ts b/js/core/src/flow.ts index 6a9948be6..60d193146 100644 --- a/js/core/src/flow.ts +++ b/js/core/src/flow.ts @@ -91,32 +91,37 @@ export interface StreamingFlowConfig< streamSchema?: S; } +export interface FlowCallOptions { + /** @deprecated use {@link context} instead. */ + withLocalAuthContext?: unknown; + context?: unknown; +} + /** * Non-streaming flow that can be called directly like a function. */ export interface CallableFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, > { - ( - input?: z.infer, - opts?: { withLocalAuthContext?: unknown } - ): Promise>; + (input?: z.infer, opts?: FlowCallOptions): Promise>; + + stream(input?: z.infer, opts?: FlowCallOptions): StreamingResponse; + flow: Flow; } /** * Streaming flow that can be called directly like a function. + * @deprecated use {@link CallableFlow} */ export interface StreamableFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, > { - ( - input?: z.infer, - opts?: { withLocalAuthContext?: unknown } - ): StreamingResponse; + (input?: z.infer, opts?: FlowCallOptions): StreamingResponse; flow: Flow; } @@ -128,7 +133,7 @@ interface StreamingResponse< S extends z.ZodTypeAny = z.ZodTypeAny, > { /** Iterator over the streaming chunks. */ - stream: AsyncGenerator, z.infer | undefined>; + stream: AsyncGenerator>; /** Final output of the flow. */ output: Promise>; } @@ -144,7 +149,7 @@ export type FlowFn< /** Input to the flow. */ input: z.infer, /** Callback for streaming functions only. */ - streamingCallback?: StreamingCallback> + streamingCallback: StreamingCallback> ) => Promise> | z.infer; /** @@ -223,7 +228,10 @@ export class Flow< }); try { metadata.input = input; - const output = await this.flowFn(input, opts.streamingCallback); + const output = await this.flowFn( + input, + opts.streamingCallback ?? (() => {}) + ); metadata.output = JSON.stringify(output); setCustomMetadataAttribute(flowMetadataPrefix('state'), 'done'); return { @@ -252,10 +260,7 @@ export class Flow< /** * Runs the flow. This is used when calling a flow from another flow. */ - async run( - payload?: z.infer, - opts?: { withLocalAuthContext?: unknown } - ): Promise> { + async run(payload?: z.infer, opts?: FlowCallOptions): Promise> { const input = this.inputSchema ? this.inputSchema.parse(payload) : payload; await this.authPolicy?.(opts?.withLocalAuthContext, payload); @@ -266,7 +271,7 @@ export class Flow< } const result = await this.invoke(input, { - auth: opts?.withLocalAuthContext, + auth: opts?.context || opts?.withLocalAuthContext, }); return result.result; } @@ -276,7 +281,7 @@ export class Flow< */ stream( payload?: z.infer, - opts?: { withLocalAuthContext?: unknown } + opts?: FlowCallOptions ): StreamingResponse { let chunkStreamController: ReadableStreamController>; const chunkStream = new ReadableStream>({ @@ -288,7 +293,7 @@ export class Flow< }); const authPromise = - this.authPolicy?.(opts?.withLocalAuthContext, payload) ?? + this.authPolicy?.(opts?.context || opts?.withLocalAuthContext, payload) ?? Promise.resolve(); const invocationPromise = authPromise @@ -301,7 +306,7 @@ export class Flow< }) as S extends z.ZodVoid ? undefined : StreamingCallback>, - auth: opts?.withLocalAuthContext, + auth: opts?.context || opts?.withLocalAuthContext, } ).then((s) => s.result) ) @@ -530,21 +535,31 @@ export class FlowServer { export function defineFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, + S extends z.ZodTypeAny = z.ZodTypeAny, >( registry: Registry, - config: FlowConfig | string, - fn: FlowFn -): CallableFlow { + config: StreamingFlowConfig | string, + fn: FlowFn +): CallableFlow { const resolvedConfig: FlowConfig = typeof config === 'string' ? { name: config } : config; - const flow = new Flow(registry, resolvedConfig, fn); + const flow = new Flow(registry, resolvedConfig, fn); registerFlowAction(registry, flow); - const callableFlow: CallableFlow = async (input, opts) => { + const callableFlow = async ( + input: z.infer, + opts: FlowCallOptions + ): Promise> => { return flow.run(input, opts); }; - callableFlow.flow = flow; - return callableFlow; + (callableFlow as CallableFlow).flow = flow; + (callableFlow as CallableFlow).stream = ( + input: z.infer, + opts: FlowCallOptions + ): StreamingResponse => { + return flow.stream(input, opts); + }; + return callableFlow as CallableFlow; } /** diff --git a/js/core/tests/flow_test.ts b/js/core/tests/flow_test.ts index 141e244b1..5a3df6e7f 100644 --- a/js/core/tests/flow_test.ts +++ b/js/core/tests/flow_test.ts @@ -14,11 +14,19 @@ * limitations under the License. */ +import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base'; import assert from 'node:assert'; import { beforeEach, describe, it } from 'node:test'; -import { defineFlow, defineStreamingFlow } from '../src/flow.js'; -import { getFlowAuth, z } from '../src/index.js'; +import { defineFlow, defineStreamingFlow, run } from '../src/flow.js'; +import { defineAction, getFlowAuth, z } from '../src/index.js'; import { Registry } from '../src/registry.js'; +import { enableTelemetry } from '../src/tracing.js'; +import { TestSpanExporter } from './utils.js'; + +const spanExporter = new TestSpanExporter(); +enableTelemetry({ + spanProcessors: [new SimpleSpanProcessor(spanExporter)], +}); function createTestFlow(registry: Registry) { return defineFlow( @@ -224,4 +232,99 @@ describe('flow', () => { assert.deepEqual(gotChunks, [{ count: 0 }, { count: 1 }, { count: 2 }]); }); }); + + describe('telemetry', async () => { + beforeEach(() => { + spanExporter.exportedSpans = []; + }); + + it('should create a trace', async () => { + const testFlow = createTestFlow(registry); + + const result = await testFlow('foo'); + + assert.equal(result, 'bar foo'); + assert.strictEqual(spanExporter.exportedSpans.length, 1); + assert.strictEqual(spanExporter.exportedSpans[0].displayName, 'testFlow'); + assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, { + 'genkit:input': '"foo"', + 'genkit:isRoot': true, + 'genkit:metadata:flow:name': 'testFlow', + 'genkit:metadata:flow:state': 'done', + 'genkit:name': 'testFlow', + 'genkit:output': '"bar foo"', + 'genkit:path': '/{testFlow,t:flow}', + 'genkit:state': 'success', + 'genkit:type': 'flow', + }); + }); + + it('records traces of nested actions', async () => { + const testAction = defineAction( + registry, + { + name: 'testAction', + actionType: 'tool', + metadata: { type: 'tool' }, + }, + async (i) => { + return 'bar'; + } + ); + + const testFlow = defineFlow( + registry, + { + name: 'testFlow', + inputSchema: z.string(), + outputSchema: z.string(), + }, + async (input) => { + return run('custom', async () => { + return 'foo ' + (await testAction(undefined)); + }); + } + ); + const result = await testFlow('foo'); + + assert.equal(result, 'foo bar'); + assert.strictEqual(spanExporter.exportedSpans.length, 3); + + assert.strictEqual( + spanExporter.exportedSpans[0].displayName, + 'testAction' + ); + assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, { + 'genkit:metadata:subtype': 'tool', + 'genkit:name': 'testAction', + 'genkit:output': '"bar"', + 'genkit:path': + '/{testFlow,t:flow}/{custom,t:flowStep}/{testAction,t:action,s:tool}', + 'genkit:state': 'success', + 'genkit:type': 'action', + }); + + assert.strictEqual(spanExporter.exportedSpans[1].displayName, 'custom'); + assert.deepStrictEqual(spanExporter.exportedSpans[1].attributes, { + 'genkit:name': 'custom', + 'genkit:output': '"foo bar"', + 'genkit:path': '/{testFlow,t:flow}/{custom,t:flowStep}', + 'genkit:state': 'success', + 'genkit:type': 'flowStep', + }); + + assert.strictEqual(spanExporter.exportedSpans[2].displayName, 'testFlow'); + assert.deepStrictEqual(spanExporter.exportedSpans[2].attributes, { + 'genkit:input': '"foo"', + 'genkit:isRoot': true, + 'genkit:metadata:flow:name': 'testFlow', + 'genkit:metadata:flow:state': 'done', + 'genkit:name': 'testFlow', + 'genkit:output': '"foo bar"', + 'genkit:path': '/{testFlow,t:flow}', + 'genkit:state': 'success', + 'genkit:type': 'flow', + }); + }); + }); }); diff --git a/js/core/tests/utils.ts b/js/core/tests/utils.ts new file mode 100644 index 000000000..72470e7f0 --- /dev/null +++ b/js/core/tests/utils.ts @@ -0,0 +1,60 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { SpanKind } from '@opentelemetry/api'; +import { ExportResult } from '@opentelemetry/core'; +import { ReadableSpan, SpanExporter } from '@opentelemetry/sdk-trace-base'; + +export class TestSpanExporter implements SpanExporter { + exportedSpans: any[] = []; + + export( + spans: ReadableSpan[], + resultCallback: (result: ExportResult) => void + ): void { + this.exportedSpans.push(...spans.map((s) => this._exportInfo(s))); + resultCallback({ code: 0 }); + } + + shutdown(): Promise { + return this.forceFlush(); + } + + private _exportInfo(span: ReadableSpan) { + return { + spanId: span.spanContext().spanId, + traceId: span.spanContext().traceId, + attributes: { ...span.attributes }, + displayName: span.name, + links: span.links, + spanKind: SpanKind[span.kind], + parentSpanId: span.parentSpanId, + sameProcessAsParentSpan: { value: !span.spanContext().isRemote }, + status: span.status, + timeEvents: { + timeEvent: span.events.map((e) => ({ + annotation: { + attributes: e.attributes ?? {}, + description: e.name, + }, + })), + }, + }; + } + forceFlush(): Promise { + return Promise.resolve(); + } +} diff --git a/js/genkit/src/genkit.ts b/js/genkit/src/genkit.ts index fa05208c4..e091ca222 100644 --- a/js/genkit/src/genkit.ts +++ b/js/genkit/src/genkit.ts @@ -110,7 +110,6 @@ import { defineSchema, defineStreamingFlow, Flow, - FlowConfig, FlowFn, FlowServer, FlowServerOptions, @@ -203,7 +202,11 @@ export class Genkit { defineFlow< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, - >(config: FlowConfig | string, fn: FlowFn): CallableFlow { + S extends z.ZodTypeAny = z.ZodTypeAny, + >( + config: StreamingFlowConfig | string, + fn: FlowFn + ): CallableFlow { const flow = defineFlow(this.registry, config, fn); this.registeredFlows.push(flow.flow); return flow; @@ -212,7 +215,7 @@ export class Genkit { /** * Defines and registers a streaming flow. * - * @todo TODO: Improve this documentation (show snippetss, etc). + * @deprecated use {@link defineFlow} */ defineStreamingFlow< I extends z.ZodTypeAny = z.ZodTypeAny, diff --git a/js/genkit/tests/flow_test.ts b/js/genkit/tests/flow_test.ts new file mode 100644 index 000000000..714b5b111 --- /dev/null +++ b/js/genkit/tests/flow_test.ts @@ -0,0 +1,81 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { z } from '@genkit-ai/core'; +import assert from 'node:assert'; +import { beforeEach, describe, it } from 'node:test'; +import { Genkit, genkit } from '../src/genkit'; + +describe('flow', () => { + let ai: Genkit; + + beforeEach(() => { + ai = genkit({}); + }); + + it('calls simple flow', async () => { + const bananaFlow = ai.defineFlow('banana', () => 'banana'); + + assert.strictEqual(await bananaFlow(), 'banana'); + }); + + it('streams simple chunks (no schema defined)', async () => { + const streamingBananaFlow = ai.defineFlow( + 'banana', + (input: string, sendChunk) => { + for (let i = 0; i < input.length; i++) { + sendChunk(input.charAt(i)); + } + return input; + } + ); + + const { stream, output } = streamingBananaFlow.stream('banana'); + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk as string); + } + assert.strictEqual(await output, 'banana'); + assert.deepStrictEqual(chunks, ['b', 'a', 'n', 'a', 'n', 'a']); + }); + + it('streams simple chunks with schema defined', async () => { + const streamingBananaFlow = ai.defineFlow( + { + name: 'banana', + inputSchema: z.string(), + streamSchema: z.string(), + }, + (input, sendChunk) => { + for (let i = 0; i < input.length; i++) { + sendChunk(input.charAt(i)); + } + return input; + } + ); + + const { stream, output } = streamingBananaFlow.stream('banana'); + let chunks: string[] = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + assert.deepStrictEqual(chunks, ['b', 'a', 'n', 'a', 'n', 'a']); + assert.strictEqual(await output, 'banana'); + + // a "streaming" flow can be invoked in non-streaming mode. + assert.strictEqual(await streamingBananaFlow('banana2'), 'banana2'); + }); +});