diff --git a/sdk/core/core-http/src/fetchHttpClient.ts b/sdk/core/core-http/src/fetchHttpClient.ts index 203489f08edc..59cce702e44b 100644 --- a/sdk/core/core-http/src/fetchHttpClient.ts +++ b/sdk/core/core-http/src/fetchHttpClient.ts @@ -213,7 +213,10 @@ export abstract class FetchHttpClient implements HttpClient { } let downloadStreamDone = Promise.resolve(); if (isReadableStream(operationResponse?.readableStreamBody)) { - downloadStreamDone = isStreamComplete(operationResponse!.readableStreamBody); + downloadStreamDone = isStreamComplete( + operationResponse!.readableStreamBody, + abortController + ); } Promise.all([uploadStreamDone, downloadStreamDone]) @@ -237,11 +240,14 @@ function isReadableStream(body: any): body is Readable { return body && typeof body.pipe === "function"; } -function isStreamComplete(stream: Readable): Promise { +function isStreamComplete(stream: Readable, aborter?: AbortController): Promise { return new Promise((resolve) => { - stream.on("close", resolve); - stream.on("end", resolve); - stream.on("error", resolve); + stream.once("close", () => { + aborter?.abort(); + resolve(); + }); + stream.once("end", resolve); + stream.once("error", resolve); }); } diff --git a/sdk/core/core-http/test/defaultHttpClientTests.node.ts b/sdk/core/core-http/test/defaultHttpClientTests.node.ts index f312e53e808d..ae635d825982 100644 --- a/sdk/core/core-http/test/defaultHttpClientTests.node.ts +++ b/sdk/core/core-http/test/defaultHttpClientTests.node.ts @@ -12,10 +12,11 @@ import { createReadStream, ReadStream } from "fs"; import { DefaultHttpClient } from "../src/defaultHttpClient"; import { WebResource, TransferProgressEvent } from "../src/webResource"; import { getHttpMock, HttpMockFacade } from "./mockHttp"; -import { PassThrough } from "stream"; +import { PassThrough, Readable } from "stream"; import { ReportTransform, CommonResponse } from "../src/fetchHttpClient"; import { CompositeMapper, Serializer } from "../src/serializer"; import { OperationSpec } from "../src/operationSpec"; +import { AbortController } from "@azure/abort-controller"; describe("defaultHttpClient (node)", function() { let httpMock: HttpMockFacade; @@ -427,6 +428,51 @@ describe("defaultHttpClient (node)", function() { requestInit2.agent.proxyOptions.proxyAuth ); }); + + it("should abort connection when download stream is closed", async function() { + const payload = new PassThrough(); + const b = new PassThrough(); + b.pipe(payload, { end: false }); + b.write("hello"); + const response = { + status: 200, + headers: [], + body: payload + }; + + let signal: AbortSignal | undefined; + const client = new DefaultHttpClient(); + sinon.stub(client, "fetch").callsFake(async (_input, init) => { + assert.ok(init, "expecting valid request initialization"); + signal = init!.signal; + return (response as unknown) as CommonResponse; + }); + + const ac = new AbortController(); + const request = new WebResource( + "http://myhost/bigdownload", + "GET", + undefined, + undefined, + undefined, + true + ); + request.abortSignal = ac.signal; + const promise = client.sendRequest(request); + + const res = await promise; + assert.ok(res.readableStreamBody, "Expecting valid download stream"); + + assert.ok(signal, "Expecting valid signal"); + const abortFiredPromise = new Promise((resolve) => { + signal!.onabort = () => { + resolve(); + }; + }); + const stream: Readable = res.readableStreamBody as any; + stream.destroy(); + await abortFiredPromise; // 'abort' event fired + }); }); describe("ReportTransform", function() {