Skip to content

Commit

Permalink
feat(js/flows): consolidated defineFlow and defineStreamingFlow (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
pavelgj authored Dec 4, 2024
1 parent 0dbc518 commit f69dac3
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 31 deletions.
67 changes: 41 additions & 26 deletions js/core/src/flow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,32 +91,37 @@ export interface StreamingFlowConfig<
streamSchema?: S;
}

export interface FlowCallOptions {
/** @deprecated use {@link context} instead. */
withLocalAuthContext?: unknown;
context?: unknown;
}

/**
* Non-streaming flow that can be called directly like a function.
*/
export interface CallableFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
(
input?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): Promise<z.infer<O>>;
(input?: z.infer<I>, opts?: FlowCallOptions): Promise<z.infer<O>>;

stream(input?: z.infer<I>, opts?: FlowCallOptions): StreamingResponse<O, S>;

flow: Flow<I, O, z.ZodVoid>;
}

/**
* Streaming flow that can be called directly like a function.
* @deprecated use {@link CallableFlow}
*/
export interface StreamableFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
(
input?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): StreamingResponse<O, S>;
(input?: z.infer<I>, opts?: FlowCallOptions): StreamingResponse<O, S>;
flow: Flow<I, O, S>;
}

Expand All @@ -128,7 +133,7 @@ interface StreamingResponse<
S extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** Iterator over the streaming chunks. */
stream: AsyncGenerator<unknown, z.infer<O>, z.infer<S> | undefined>;
stream: AsyncGenerator<z.infer<S>>;
/** Final output of the flow. */
output: Promise<z.infer<O>>;
}
Expand All @@ -144,7 +149,7 @@ export type FlowFn<
/** Input to the flow. */
input: z.infer<I>,
/** Callback for streaming functions only. */
streamingCallback?: StreamingCallback<z.infer<S>>
streamingCallback: StreamingCallback<z.infer<S>>
) => Promise<z.infer<O>> | z.infer<O>;

/**
Expand Down Expand Up @@ -223,7 +228,10 @@ export class Flow<
});
try {
metadata.input = input;
const output = await this.flowFn(input, opts.streamingCallback);
const output = await this.flowFn(
input,
opts.streamingCallback ?? (() => {})
);
metadata.output = JSON.stringify(output);
setCustomMetadataAttribute(flowMetadataPrefix('state'), 'done');
return {
Expand Down Expand Up @@ -252,10 +260,7 @@ export class Flow<
/**
* Runs the flow. This is used when calling a flow from another flow.
*/
async run(
payload?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
): Promise<z.infer<O>> {
async run(payload?: z.infer<I>, opts?: FlowCallOptions): Promise<z.infer<O>> {
const input = this.inputSchema ? this.inputSchema.parse(payload) : payload;
await this.authPolicy?.(opts?.withLocalAuthContext, payload);

Expand All @@ -266,7 +271,7 @@ export class Flow<
}

const result = await this.invoke(input, {
auth: opts?.withLocalAuthContext,
auth: opts?.context || opts?.withLocalAuthContext,
});
return result.result;
}
Expand All @@ -276,7 +281,7 @@ export class Flow<
*/
stream(
payload?: z.infer<I>,
opts?: { withLocalAuthContext?: unknown }
opts?: FlowCallOptions
): StreamingResponse<O, S> {
let chunkStreamController: ReadableStreamController<z.infer<S>>;
const chunkStream = new ReadableStream<z.infer<S>>({
Expand All @@ -288,7 +293,7 @@ export class Flow<
});

const authPromise =
this.authPolicy?.(opts?.withLocalAuthContext, payload) ??
this.authPolicy?.(opts?.context || opts?.withLocalAuthContext, payload) ??
Promise.resolve();

const invocationPromise = authPromise
Expand All @@ -301,7 +306,7 @@ export class Flow<
}) as S extends z.ZodVoid
? undefined
: StreamingCallback<z.infer<S>>,
auth: opts?.withLocalAuthContext,
auth: opts?.context || opts?.withLocalAuthContext,
}
).then((s) => s.result)
)
Expand Down Expand Up @@ -530,21 +535,31 @@ export class FlowServer {
export function defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
registry: Registry,
config: FlowConfig<I, O> | string,
fn: FlowFn<I, O, z.ZodVoid>
): CallableFlow<I, O> {
config: StreamingFlowConfig<I, O> | string,
fn: FlowFn<I, O, S>
): CallableFlow<I, O, S> {
const resolvedConfig: FlowConfig<I, O> =
typeof config === 'string' ? { name: config } : config;

const flow = new Flow<I, O, z.ZodVoid>(registry, resolvedConfig, fn);
const flow = new Flow<I, O, S>(registry, resolvedConfig, fn);
registerFlowAction(registry, flow);
const callableFlow: CallableFlow<I, O> = async (input, opts) => {
const callableFlow = async (
input: z.infer<I>,
opts: FlowCallOptions
): Promise<z.infer<O>> => {
return flow.run(input, opts);
};
callableFlow.flow = flow;
return callableFlow;
(callableFlow as CallableFlow<I, O, S>).flow = flow;
(callableFlow as CallableFlow<I, O, S>).stream = (
input: z.infer<I>,
opts: FlowCallOptions
): StreamingResponse<O, S> => {
return flow.stream(input, opts);
};
return callableFlow as CallableFlow<I, O, S>;
}

/**
Expand Down
107 changes: 105 additions & 2 deletions js/core/tests/flow_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,19 @@
* limitations under the License.
*/

import { SimpleSpanProcessor } from '@opentelemetry/sdk-trace-base';
import assert from 'node:assert';
import { beforeEach, describe, it } from 'node:test';
import { defineFlow, defineStreamingFlow } from '../src/flow.js';
import { getFlowAuth, z } from '../src/index.js';
import { defineFlow, defineStreamingFlow, run } from '../src/flow.js';
import { defineAction, getFlowAuth, z } from '../src/index.js';
import { Registry } from '../src/registry.js';
import { enableTelemetry } from '../src/tracing.js';
import { TestSpanExporter } from './utils.js';

const spanExporter = new TestSpanExporter();
enableTelemetry({
spanProcessors: [new SimpleSpanProcessor(spanExporter)],
});

function createTestFlow(registry: Registry) {
return defineFlow(
Expand Down Expand Up @@ -224,4 +232,99 @@ describe('flow', () => {
assert.deepEqual(gotChunks, [{ count: 0 }, { count: 1 }, { count: 2 }]);
});
});

describe('telemetry', async () => {
beforeEach(() => {
spanExporter.exportedSpans = [];
});

it('should create a trace', async () => {
const testFlow = createTestFlow(registry);

const result = await testFlow('foo');

assert.equal(result, 'bar foo');
assert.strictEqual(spanExporter.exportedSpans.length, 1);
assert.strictEqual(spanExporter.exportedSpans[0].displayName, 'testFlow');
assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, {
'genkit:input': '"foo"',
'genkit:isRoot': true,
'genkit:metadata:flow:name': 'testFlow',
'genkit:metadata:flow:state': 'done',
'genkit:name': 'testFlow',
'genkit:output': '"bar foo"',
'genkit:path': '/{testFlow,t:flow}',
'genkit:state': 'success',
'genkit:type': 'flow',
});
});

it('records traces of nested actions', async () => {
const testAction = defineAction(
registry,
{
name: 'testAction',
actionType: 'tool',
metadata: { type: 'tool' },
},
async (i) => {
return 'bar';
}
);

const testFlow = defineFlow(
registry,
{
name: 'testFlow',
inputSchema: z.string(),
outputSchema: z.string(),
},
async (input) => {
return run('custom', async () => {
return 'foo ' + (await testAction(undefined));
});
}
);
const result = await testFlow('foo');

assert.equal(result, 'foo bar');
assert.strictEqual(spanExporter.exportedSpans.length, 3);

assert.strictEqual(
spanExporter.exportedSpans[0].displayName,
'testAction'
);
assert.deepStrictEqual(spanExporter.exportedSpans[0].attributes, {
'genkit:metadata:subtype': 'tool',
'genkit:name': 'testAction',
'genkit:output': '"bar"',
'genkit:path':
'/{testFlow,t:flow}/{custom,t:flowStep}/{testAction,t:action,s:tool}',
'genkit:state': 'success',
'genkit:type': 'action',
});

assert.strictEqual(spanExporter.exportedSpans[1].displayName, 'custom');
assert.deepStrictEqual(spanExporter.exportedSpans[1].attributes, {
'genkit:name': 'custom',
'genkit:output': '"foo bar"',
'genkit:path': '/{testFlow,t:flow}/{custom,t:flowStep}',
'genkit:state': 'success',
'genkit:type': 'flowStep',
});

assert.strictEqual(spanExporter.exportedSpans[2].displayName, 'testFlow');
assert.deepStrictEqual(spanExporter.exportedSpans[2].attributes, {
'genkit:input': '"foo"',
'genkit:isRoot': true,
'genkit:metadata:flow:name': 'testFlow',
'genkit:metadata:flow:state': 'done',
'genkit:name': 'testFlow',
'genkit:output': '"foo bar"',
'genkit:path': '/{testFlow,t:flow}',
'genkit:state': 'success',
'genkit:type': 'flow',
});
});
});
});
60 changes: 60 additions & 0 deletions js/core/tests/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { SpanKind } from '@opentelemetry/api';
import { ExportResult } from '@opentelemetry/core';
import { ReadableSpan, SpanExporter } from '@opentelemetry/sdk-trace-base';

export class TestSpanExporter implements SpanExporter {
exportedSpans: any[] = [];

export(
spans: ReadableSpan[],
resultCallback: (result: ExportResult) => void
): void {
this.exportedSpans.push(...spans.map((s) => this._exportInfo(s)));
resultCallback({ code: 0 });
}

shutdown(): Promise<void> {
return this.forceFlush();
}

private _exportInfo(span: ReadableSpan) {
return {
spanId: span.spanContext().spanId,
traceId: span.spanContext().traceId,
attributes: { ...span.attributes },
displayName: span.name,
links: span.links,
spanKind: SpanKind[span.kind],
parentSpanId: span.parentSpanId,
sameProcessAsParentSpan: { value: !span.spanContext().isRemote },
status: span.status,
timeEvents: {
timeEvent: span.events.map((e) => ({
annotation: {
attributes: e.attributes ?? {},
description: e.name,
},
})),
},
};
}
forceFlush(): Promise<void> {
return Promise.resolve();
}
}
9 changes: 6 additions & 3 deletions js/genkit/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ import {
defineSchema,
defineStreamingFlow,
Flow,
FlowConfig,
FlowFn,
FlowServer,
FlowServerOptions,
Expand Down Expand Up @@ -203,7 +202,11 @@ export class Genkit {
defineFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
O extends z.ZodTypeAny = z.ZodTypeAny,
>(config: FlowConfig<I, O> | string, fn: FlowFn<I, O>): CallableFlow<I, O> {
S extends z.ZodTypeAny = z.ZodTypeAny,
>(
config: StreamingFlowConfig<I, O, S> | string,
fn: FlowFn<I, O, S>
): CallableFlow<I, O, S> {
const flow = defineFlow(this.registry, config, fn);
this.registeredFlows.push(flow.flow);
return flow;
Expand All @@ -212,7 +215,7 @@ export class Genkit {
/**
* Defines and registers a streaming flow.
*
* @todo TODO: Improve this documentation (show snippetss, etc).
* @deprecated use {@link defineFlow}
*/
defineStreamingFlow<
I extends z.ZodTypeAny = z.ZodTypeAny,
Expand Down
Loading

0 comments on commit f69dac3

Please sign in to comment.