diff --git a/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts b/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts index 138ab5e2f0e..4ea8fe78d4a 100644 --- a/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts +++ b/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts @@ -29,6 +29,19 @@ type PatchedEventEmitter = { __ot_listeners?: { [name: string]: WeakMap, Func> }; } & EventEmitter; +class Reference { + constructor(private _value: T) {} + + set(value: T) { + this._value = value; + return this; + } + + get() { + return this._value; + } +} + const ADD_LISTENER_METHODS = [ 'addListener' as 'addListener', 'on' as 'on', @@ -39,9 +52,7 @@ const ADD_LISTENER_METHODS = [ export class AsyncHooksContextManager implements ContextManager { private _asyncHook: asyncHooks.AsyncHook; - private _contexts: { - [uid: number]: Context | undefined | null; - } = Object.create(null); + private _contextRefs: Map | undefined> = new Map(); constructor() { this._asyncHook = asyncHooks.createHook({ @@ -52,9 +63,8 @@ export class AsyncHooksContextManager implements ContextManager { } active(): Context { - return ( - this._contexts[asyncHooks.executionAsyncId()] || Context.ROOT_CONTEXT - ); + const ref = this._contextRefs.get(asyncHooks.executionAsyncId()); + return ref === undefined ? Context.ROOT_CONTEXT : ref.get(); } with ReturnType>( @@ -62,8 +72,15 @@ export class AsyncHooksContextManager implements ContextManager { fn: T ): ReturnType { const uid = asyncHooks.executionAsyncId(); - const oldContext = this._contexts[uid]; - this._contexts[uid] = context; + let ref = this._contextRefs.get(uid); + let oldContext: Context | undefined = undefined; + if (ref === undefined) { + ref = new Reference(context); + this._contextRefs.set(uid, ref); + } else { + oldContext = ref.get(); + ref.set(context); + } try { return fn(); } catch (err) { @@ -72,7 +89,34 @@ export class AsyncHooksContextManager implements ContextManager { if (oldContext === undefined) { this._destroy(uid); } else { - this._contexts[uid] = oldContext; + ref.set(oldContext); + } + } + } + + async withAsync, U extends (...args: unknown[]) => T>( + context: Context, + fn: U + ): Promise { + const uid = asyncHooks.executionAsyncId(); + let ref = this._contextRefs.get(uid); + let oldContext: Context | undefined = undefined; + if (ref === undefined) { + ref = new Reference(context); + this._contextRefs.set(uid, ref); + } else { + oldContext = ref.get(); + ref.set(context); + } + try { + return await fn(); + } catch (err) { + throw err; + } finally { + if (oldContext === undefined) { + this._destroy(uid); + } else { + ref.set(oldContext); } } } @@ -97,7 +141,7 @@ export class AsyncHooksContextManager implements ContextManager { disable(): this { this._asyncHook.disable(); - this._contexts = {}; + this._contextRefs.clear(); return this; } @@ -232,7 +276,10 @@ export class AsyncHooksContextManager implements ContextManager { * @param uid id of the async context */ private _init(uid: number) { - this._contexts[uid] = this._contexts[asyncHooks.executionAsyncId()]; + const ref = this._contextRefs.get(asyncHooks.executionAsyncId()); + if (ref !== undefined) { + this._contextRefs.set(uid, ref); + } } /** @@ -241,6 +288,6 @@ export class AsyncHooksContextManager implements ContextManager { * @param uid uid of the async context */ private _destroy(uid: number) { - delete this._contexts[uid]; + this._contextRefs.delete(uid); } } diff --git a/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts b/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts index 5bc4b99edec..8c6c648a883 100644 --- a/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts +++ b/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts @@ -104,6 +104,172 @@ describe('AsyncHooksContextManager', () => { }); }); + describe('.withAsync()', () => { + it('should run the callback', async () => { + let done = false; + await contextManager.withAsync(Context.ROOT_CONTEXT, async () => { + done = true; + }); + + assert.ok(done); + }); + + it('should run the callback with active scope', async () => { + const test = Context.ROOT_CONTEXT.setValue(key1, 1); + await contextManager.withAsync(test, async () => { + assert.strictEqual(contextManager.active(), test, 'should have scope'); + }); + }); + + it('should run the callback (when disabled)', async () => { + contextManager.disable(); + let done = false; + await contextManager.withAsync(Context.ROOT_CONTEXT, async () => { + done = true; + }); + + assert.ok(done); + }); + + it('should rethrow errors', async () => { + contextManager.disable(); + let done = false; + const err = new Error(); + + try { + await contextManager.withAsync(Context.ROOT_CONTEXT, async () => { + throw err; + }); + } catch (e) { + assert.ok(e === err); + done = true; + } + + assert.ok(done); + }); + + it('should finally restore an old scope', async () => { + const scope1 = '1' as any; + const scope2 = '2' as any; + let done = false; + + await contextManager.withAsync(scope1, async () => { + assert.strictEqual(contextManager.active(), scope1); + await contextManager.withAsync(scope2, async () => { + assert.strictEqual(contextManager.active(), scope2); + done = true; + }); + assert.strictEqual(contextManager.active(), scope1); + }); + + assert.ok(done); + }); + }); + + describe('.withAsync/with()', () => { + it('with() inside withAsync() should correctly restore context', async () => { + const scope1 = '1' as any; + const scope2 = '2' as any; + let done = false; + + await contextManager.withAsync(scope1, async () => { + assert.strictEqual(contextManager.active(), scope1); + contextManager.with(scope2, () => { + assert.strictEqual(contextManager.active(), scope2); + done = true; + }); + assert.strictEqual(contextManager.active(), scope1); + }); + + assert.ok(done); + }); + + it('withAsync() inside with() should correctly restore conxtext', done => { + const scope1 = '1' as any; + const scope2 = '2' as any; + + contextManager.with(scope1, async () => { + assert.strictEqual(contextManager.active(), scope1); + await contextManager.withAsync(scope2, async () => { + assert.strictEqual(contextManager.active(), scope2); + }); + assert.strictEqual(contextManager.active(), scope1); + return done(); + }); + assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); + }); + + it('not awaited withAsync() inside with() should not restore context', done => { + const scope1 = '1' as any; + const scope2 = '2' as any; + let _done: boolean = false; + + contextManager.with(scope1, () => { + assert.strictEqual(contextManager.active(), scope1); + contextManager + .withAsync(scope2, async () => { + assert.strictEqual(contextManager.active(), scope2); + }) + .then(() => { + assert.strictEqual(contextManager.active(), scope1); + _done = true; + }); + // in this case the current scope is 2 since we + // didnt waited the withAsync call + assert.strictEqual(contextManager.active(), scope2); + setTimeout(() => { + assert.strictEqual(contextManager.active(), scope1); + assert(_done); + return done(); + }, 100); + }); + assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); + }); + + it('withAsync() inside a setTimeout inside a with() should correctly restore context', done => { + const scope1 = '1' as any; + const scope2 = '2' as any; + + contextManager.with(scope1, () => { + assert.strictEqual(contextManager.active(), scope1); + setTimeout(() => { + assert.strictEqual(contextManager.active(), scope1); + contextManager + .withAsync(scope2, async () => { + assert.strictEqual(contextManager.active(), scope2); + }) + .then(() => { + assert.strictEqual(contextManager.active(), scope1); + return done(); + }); + }, 5); + assert.strictEqual(contextManager.active(), scope1); + }); + assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); + }); + + it('with() inside a setTimeout inside withAsync() should correctly restore context', done => { + const scope1 = '1' as any; + const scope2 = '2' as any; + + contextManager + .withAsync(scope1, async () => { + assert.strictEqual(contextManager.active(), scope1); + setTimeout(() => { + assert.strictEqual(contextManager.active(), scope1); + contextManager.with(scope2, () => { + assert.strictEqual(contextManager.active(), scope2); + return done(); + }); + }, 5); + assert.strictEqual(contextManager.active(), scope1); + }) + .then(() => { + assert.strictEqual(contextManager.active(), scope1); + }); + }); + }); + describe('.bind(function)', () => { it('should return the same target (when enabled)', () => { const test = { a: 1 };