From 1fc402c4781c489e8a4946e4f696e1f4c6f977c8 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Mon, 18 Nov 2024 14:35:34 +0000 Subject: [PATCH] ZLibCompressor/Decompressor (#22) * Use non-copyable zlib compressor/decompressor * make algorithm non-mutable * Make DecompressByteBufferSequence a struct * No need to make DecompressByteBufferSequence.makeAsyncSequence consuming * Update for CompressNIO Zlib updates Cleaned up DecompressByteBufferSequence by using a state machine, also allocate Decompressor on first call to `next()`. * swift format * Include window in state machine, make iterator a struct * Use compress-nio 1.3.0 --- Package.swift | 2 +- .../CompressedBodyWriter.swift | 38 +++---- .../RequestDecompressionMiddleware.swift | 106 ++++++++++-------- .../ResponseCompressionMiddleware.swift | 9 +- 4 files changed, 79 insertions(+), 76 deletions(-) diff --git a/Package.swift b/Package.swift index be88733..2b2cf92 100644 --- a/Package.swift +++ b/Package.swift @@ -11,7 +11,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "2.0.0"), - .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.2.1"), + .package(url: "https://github.com/adam-fowler/compress-nio.git", from: "1.3.0"), .package(url: "https://github.com/apple/swift-nio.git", from: "2.32.1"), ], targets: [ diff --git a/Sources/HummingbirdCompression/CompressedBodyWriter.swift b/Sources/HummingbirdCompression/CompressedBodyWriter.swift index 8254375..9548967 100644 --- a/Sources/HummingbirdCompression/CompressedBodyWriter.swift +++ b/Sources/HummingbirdCompression/CompressedBodyWriter.swift @@ -19,36 +19,29 @@ import Logging // ResponseBodyWriter that writes a compressed version of the response to a parent writer final class CompressedBodyWriter: ResponseBodyWriter { var parentWriter: ParentWriter - let compressor: NIOCompressor + private let compressor: ZlibCompressor + private var window: ByteBuffer var lastBuffer: ByteBuffer? let logger: Logger init( parent: ParentWriter, - algorithm: CompressionAlgorithm, + algorithm: ZlibAlgorithm, + configuration: ZlibConfiguration, windowSize: Int, logger: Logger ) throws { self.parentWriter = parent - self.compressor = algorithm.compressor - self.compressor.window = ByteBufferAllocator().buffer(capacity: windowSize) + self.compressor = try ZlibCompressor(algorithm: algorithm, configuration: configuration) + self.window = ByteBufferAllocator().buffer(capacity: windowSize) self.lastBuffer = nil self.logger = logger - try self.compressor.startStream() - } - - deinit { - do { - try self.compressor.finishStream() - } catch { - logger.error("Error finalizing compression stream: \(error) ") - } } /// Write response buffer func write(_ buffer: ByteBuffer) async throws { var buffer = buffer - try await buffer.compressStream(with: self.compressor, flush: .sync) { buffer in + try await buffer.compressStream(with: self.compressor, window: &self.window, flush: .sync) { buffer in try await self.parentWriter.write(buffer) } // need to store the last buffer so it can be finished once the writer is done @@ -59,17 +52,17 @@ final class CompressedBodyWriter: R /// - Parameter trailingHeaders: Any trailing headers you want to include at end consuming func finish(_ trailingHeaders: HTTPFields?) async throws { // The last buffer must be finished - if var lastBuffer, var window = self.compressor.window { + if var lastBuffer { // keep finishing stream until we don't get a buffer overflow while true { do { - try lastBuffer.compressStream(to: &window, with: self.compressor, flush: .finish) - try await self.parentWriter.write(window) - window.clear() + try lastBuffer.compressStream(to: &self.window, with: self.compressor, flush: .finish) + try await self.parentWriter.write(self.window) + self.window.clear() break } catch let error as CompressNIOError where error == .bufferOverflow { - try await self.parentWriter.write(window) - window.clear() + try await self.parentWriter.write(self.window) + self.window.clear() } } } @@ -87,10 +80,11 @@ extension ResponseBodyWriter { /// - logger: Logger used to output compression errors /// - Returns: new ``HummingbirdCore/ResponseBodyWriter`` public func compressed( - algorithm: CompressionAlgorithm, + algorithm: ZlibAlgorithm, + configuration: ZlibConfiguration, windowSize: Int, logger: Logger ) throws -> some ResponseBodyWriter { - try CompressedBodyWriter(parent: self, algorithm: algorithm, windowSize: windowSize, logger: logger) + try CompressedBodyWriter(parent: self, algorithm: algorithm, configuration: configuration, windowSize: windowSize, logger: logger) } } diff --git a/Sources/HummingbirdCompression/RequestDecompressionMiddleware.swift b/Sources/HummingbirdCompression/RequestDecompressionMiddleware.swift index b516845..07e8bff 100644 --- a/Sources/HummingbirdCompression/RequestDecompressionMiddleware.swift +++ b/Sources/HummingbirdCompression/RequestDecompressionMiddleware.swift @@ -47,13 +47,13 @@ public struct RequestDecompressionMiddleware: RouterMid } /// Determines the decompression algorithm based off content encoding header. - private func algorithm(from contentEncodingHeaders: [String]) -> CompressionAlgorithm? { + private func algorithm(from contentEncodingHeaders: [String]) -> ZlibAlgorithm? { for encoding in contentEncodingHeaders { switch encoding { case "gzip": - return CompressionAlgorithm.gzip() + return .gzip case "deflate": - return CompressionAlgorithm.zlib() + return .zlib default: break } @@ -67,67 +67,75 @@ struct DecompressByteBufferSequence: AsyncSequen typealias Element = ByteBuffer let base: Base - let algorithm: CompressionAlgorithm + let algorithm: ZlibAlgorithm let windowSize: Int let logger: Logger - class AsyncIterator: AsyncIteratorProtocol { - var baseIterator: Base.AsyncIterator - let decompressor: NIODecompressor - var currentBuffer: ByteBuffer? - var window: ByteBuffer - let logger: Logger + init(base: Base, algorithm: ZlibAlgorithm, windowSize: Int, logger: Logger) { + self.base = base + self.algorithm = algorithm + self.windowSize = windowSize + self.logger = logger + } - init(baseIterator: Base.AsyncIterator, algorithm: CompressionAlgorithm, windowSize: Int, logger: Logger) { - self.baseIterator = baseIterator - self.decompressor = algorithm.decompressor - self.window = ByteBufferAllocator().buffer(capacity: windowSize) - self.currentBuffer = nil - self.logger = logger - do { - try self.decompressor.startStream() - } catch { - logger.error("Error initializing decompression stream: \(error) ") - } + struct AsyncIterator: AsyncIteratorProtocol { + enum State { + case uninitialized(ZlibAlgorithm, windowSize: Int) + case decompressing(ZlibDecompressor, buffer: ByteBuffer, window: ByteBuffer) + case done } - deinit { - do { - try self.decompressor.finishStream() - } catch { - logger.error("Error finalizing decompression stream: \(error) ") - } + var baseIterator: Base.AsyncIterator + var state: State + + init(baseIterator: Base.AsyncIterator, algorithm: ZlibAlgorithm, windowSize: Int) { + self.baseIterator = baseIterator + self.state = .uninitialized(algorithm, windowSize: windowSize) } - func next() async throws -> ByteBuffer? { - do { - if self.currentBuffer == nil { - self.currentBuffer = try await self.baseIterator.next() + mutating func next() async throws -> ByteBuffer? { + switch self.state { + case .uninitialized(let algorithm, let windowSize): + guard let buffer = try await self.baseIterator.next() else { + self.state = .done + return nil } - self.window.clear() - while var buffer = self.currentBuffer { - do { - try buffer.decompressStream(to: &self.window, with: self.decompressor) - } catch let error as CompressNIOError where error == .bufferOverflow { - self.currentBuffer = buffer - return self.window - } catch let error as CompressNIOError where error == .inputBufferOverflow { - // can ignore CompressNIOError.inputBufferOverflow errors here - } + let decompressor = try ZlibDecompressor(algorithm: algorithm) + self.state = .decompressing(decompressor, buffer: buffer, window: ByteBufferAllocator().buffer(capacity: windowSize)) + return try await self.next() - self.currentBuffer = try await self.baseIterator.next() + case .decompressing(let decompressor, var buffer, var window): + do { + window.clear() + while true { + do { + try buffer.decompressStream(to: &window, with: decompressor) + } catch let error as CompressNIOError where error == .bufferOverflow { + self.state = .decompressing(decompressor, buffer: buffer, window: window) + return window + } catch let error as CompressNIOError where error == .inputBufferOverflow { + // can ignore CompressNIOError.inputBufferOverflow errors here + } + + guard let nextBuffer = try await self.baseIterator.next() else { + self.state = .done + return window.readableBytes > 0 ? window : nil + } + buffer = nextBuffer + } + } catch let error as CompressNIOError where error == .corruptData { + throw HTTPError(.badRequest, message: "Corrupt compression data.") + } catch { + throw HTTPError(.badRequest, message: "Data decompression failed.") } - self.currentBuffer = nil - return self.window.readableBytes > 0 ? self.window : nil - } catch let error as CompressNIOError where error == .corruptData { - throw HTTPError(.badRequest, message: "Corrupt compression data.") - } catch { - throw HTTPError(.badRequest, message: "Data decompression failed.") + + case .done: + return nil } } } func makeAsyncIterator() -> AsyncIterator { - .init(baseIterator: self.base.makeAsyncIterator(), algorithm: self.algorithm, windowSize: self.windowSize, logger: self.logger) + .init(baseIterator: self.base.makeAsyncIterator(), algorithm: self.algorithm, windowSize: self.windowSize) } } diff --git a/Sources/HummingbirdCompression/ResponseCompressionMiddleware.swift b/Sources/HummingbirdCompression/ResponseCompressionMiddleware.swift index 7030b28..42f89a4 100644 --- a/Sources/HummingbirdCompression/ResponseCompressionMiddleware.swift +++ b/Sources/HummingbirdCompression/ResponseCompressionMiddleware.swift @@ -70,6 +70,7 @@ public struct ResponseCompressionMiddleware: RouterMidd editedResponse.body = .init { writer in let compressWriter = try writer.compressed( algorithm: algorithm, + configuration: self.zlibConfiguration, windowSize: self.windowSize, logger: context.logger ) @@ -95,7 +96,7 @@ public struct ResponseCompressionMiddleware: RouterMidd } /// Determines the compression algorithm to use for the next response. - private func compressionAlgorithm(from acceptContentHeaders: [some StringProtocol]) -> (compressor: CompressionAlgorithm, name: String)? { + private func compressionAlgorithm(from acceptContentHeaders: [some StringProtocol]) -> (algorithm: ZlibAlgorithm, name: String)? { var gzipQValue: Float = -1 var deflateQValue: Float = -1 var anyQValue: Float = -1 @@ -112,15 +113,15 @@ public struct ResponseCompressionMiddleware: RouterMidd if gzipQValue > 0 || deflateQValue > 0 { if gzipQValue > deflateQValue { - return (compressor: CompressionAlgorithm.gzip(configuration: self.zlibConfiguration), name: "gzip") + return (algorithm: .gzip, name: "gzip") } else { - return (compressor: CompressionAlgorithm.zlib(configuration: self.zlibConfiguration), name: "deflate") + return (algorithm: .zlib, name: "deflate") } } else if anyQValue > 0 { // Though gzip is usually less well compressed than deflate, it has slightly // wider support because it's unabiguous. We therefore default to that unless // the client has expressed a preference. - return (compressor: CompressionAlgorithm.gzip(configuration: self.zlibConfiguration), name: "gzip") + return (algorithm: .gzip, name: "gzip") } return nil