From 2c0e8df0fd7af3a7a9e420dc1756de5443a81a81 Mon Sep 17 00:00:00 2001 From: Tomoki Sekiyama Date: Fri, 24 Jan 2025 20:19:56 +0900 Subject: [PATCH] feat: support CDict & DDict --- README.md | 37 +++++++++ ext/zstdruby/common.h | 46 ++++++++--- ext/zstdruby/main.c | 4 + ext/zstdruby/zstdruby.c | 92 +++++++++++++++++++++ spec/zstd-ruby-streaming-compress_spec.rb | 21 ++++- spec/zstd-ruby-streaming-decompress_spec.rb | 24 +++++- spec/zstd-ruby-using-dict_spec.rb | 38 ++++++++- 7 files changed, 249 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 8058701..bf50907 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,22 @@ compressed_data = Zstd.compress(data, level: complession_level) # default compre compressed_using_dict = Zstd.compress("", dict: File.read('dictionary_file')) ``` +#### Compression with CDict + +If you use the same dictionary repeatedly, you can speed up the setup by creating CDict in advance: + +```ruby +cdict = Zstd::CDict.new(File.read('dictionary_file')) +compressed_using_dict = Zstd.compress("", dict: cdict) +``` + +The compression_level can be specified on creating CDict. 
+ +```ruby +cdict = Zstd::CDict.new(File.read('dictionary_file'), 5) +compressed_using_dict = Zstd.compress("", dict: cdict) +``` + #### Streaming Compression ```ruby stream = Zstd::StreamingCompress.new @@ -86,6 +102,16 @@ stream << "ghi" res << stream.finish ``` +#### Streaming Compression with CDict of level 5 +```ruby +cdict = Zstd::CDict.new(File.read('dictionary_file'), 5) +stream = Zstd::StreamingCompress.new(dict: cdict) +stream << "abc" << "def" +res = stream.flush +stream << "ghi" +res << stream.finish +``` + ### Decompression #### Simple Decompression @@ -100,6 +126,15 @@ data = Zstd.decompress(compressed_data) Zstd.decompress(compressed_using_dict, dict: File.read('dictionary_file')) ``` +#### Decompression with DDict + +If you use the same dictionary repeatedly, you can speed up the setup by creating a DDict in advance: + +```ruby +ddict = Zstd::DDict.new(File.read('dictionary_file')) +data = Zstd.decompress(compressed_using_dict, dict: ddict) +``` + #### Streaming Decompression ```ruby cstr = "" # Compressed data @@ -118,6 +153,8 @@ result << stream.decompress(cstr[0, 10]) result << stream.decompress(cstr[10..-1]) ``` +A DDict can also be passed as `dict:`. 
+ ### Skippable frame ```ruby diff --git a/ext/zstdruby/common.h b/ext/zstdruby/common.h index da437cd..996b9e3 100644 --- a/ext/zstdruby/common.h +++ b/ext/zstdruby/common.h @@ -8,6 +8,8 @@ #include #include "./libzstd/zstd.h" +extern VALUE rb_cCDict, rb_cDDict; + static int convert_compression_level(VALUE compression_level_value) { if (NIL_P(compression_level_value)) { @@ -34,12 +36,24 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL ZSTD_CCtx_setParameter(ctx, ZSTD_c_compressionLevel, compression_level); if (kwargs_values[1] != Qundef && kwargs_values[1] != Qnil) { - char* dict_buffer = RSTRING_PTR(kwargs_values[1]); - size_t dict_size = RSTRING_LEN(kwargs_values[1]); - size_t load_dict_ret = ZSTD_CCtx_loadDictionary(ctx, dict_buffer, dict_size); - if (ZSTD_isError(load_dict_ret)) { + if (CLASS_OF(kwargs_values[1]) == rb_cCDict) { + ZSTD_CDict* cdict = DATA_PTR(kwargs_values[1]); + size_t ref_dict_ret = ZSTD_CCtx_refCDict(ctx, cdict); + if (ZSTD_isError(ref_dict_ret)) { + ZSTD_freeCCtx(ctx); + rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_refCDict failed"); + } + } else if (TYPE(kwargs_values[1]) == T_STRING) { + char* dict_buffer = RSTRING_PTR(kwargs_values[1]); + size_t dict_size = RSTRING_LEN(kwargs_values[1]); + size_t load_dict_ret = ZSTD_CCtx_loadDictionary(ctx, dict_buffer, dict_size); + if (ZSTD_isError(load_dict_ret)) { + ZSTD_freeCCtx(ctx); + rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed"); + } + } else { ZSTD_freeCCtx(ctx); - rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed"); + rb_raise(rb_eArgError, "`dict:` must be a Zstd::CDict or a String"); } } } @@ -113,12 +127,24 @@ static void set_decompress_params(ZSTD_DCtx* const dctx, VALUE kwargs) rb_get_kwargs(kwargs, kwargs_keys, 0, 1, kwargs_values); if (kwargs_values[0] != Qundef && kwargs_values[0] != Qnil) { - char* dict_buffer = RSTRING_PTR(kwargs_values[0]); - size_t dict_size = RSTRING_LEN(kwargs_values[0]); - size_t 
load_dict_ret = ZSTD_DCtx_loadDictionary(dctx, dict_buffer, dict_size); - if (ZSTD_isError(load_dict_ret)) { + if (CLASS_OF(kwargs_values[0]) == rb_cDDict) { + ZSTD_DDict* ddict = DATA_PTR(kwargs_values[0]); + size_t ref_dict_ret = ZSTD_DCtx_refDDict(dctx, ddict); + if (ZSTD_isError(ref_dict_ret)) { + ZSTD_freeDCtx(dctx); + rb_raise(rb_eRuntimeError, "%s", "ZSTD_DCtx_refDDict failed"); + } + } else if (TYPE(kwargs_values[0]) == T_STRING) { + char* dict_buffer = RSTRING_PTR(kwargs_values[0]); + size_t dict_size = RSTRING_LEN(kwargs_values[0]); + size_t load_dict_ret = ZSTD_DCtx_loadDictionary(dctx, dict_buffer, dict_size); + if (ZSTD_isError(load_dict_ret)) { + ZSTD_freeDCtx(dctx); + rb_raise(rb_eRuntimeError, "%s", "ZSTD_DCtx_loadDictionary failed"); /* name the call that actually failed */ + } + } else { ZSTD_freeDCtx(dctx); - rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed"); + rb_raise(rb_eArgError, "`dict:` must be a Zstd::DDict or a String"); } } } diff --git a/ext/zstdruby/main.c b/ext/zstdruby/main.c index 0f2198b..e859fa0 100644 --- a/ext/zstdruby/main.c +++ b/ext/zstdruby/main.c @@ -1,6 +1,8 @@ #include "common.h" VALUE rb_mZstd; +VALUE rb_cCDict; +VALUE rb_cDDict; void zstd_ruby_init(void); void zstd_ruby_skippable_frame_init(void); void zstd_ruby_streaming_compress_init(void); @@ -14,6 +16,8 @@ Init_zstdruby(void) #endif rb_mZstd = rb_define_module("Zstd"); + rb_cCDict = rb_define_class_under(rb_mZstd, "CDict", rb_cObject); + rb_cDDict = rb_define_class_under(rb_mZstd, "DDict", rb_cObject); zstd_ruby_init(); zstd_ruby_skippable_frame_init(); zstd_ruby_streaming_compress_init(); diff --git a/ext/zstdruby/zstdruby.c b/ext/zstdruby/zstdruby.c index 512aa7b..94de5bf 100644 --- a/ext/zstdruby/zstdruby.c +++ b/ext/zstdruby/zstdruby.c @@ -195,6 +195,90 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self) return output; } +static void free_cdict(void *dict) +{ + ZSTD_freeCDict(dict); +} + +static size_t sizeof_cdict(const void *dict) +{ + return 
ZSTD_sizeof_CDict(dict); +} + +static void free_ddict(void *dict) +{ + ZSTD_freeDDict(dict); +} + +static size_t sizeof_ddict(const void *dict) +{ + return ZSTD_sizeof_DDict(dict); +} + +static const rb_data_type_t cdict_type = { + "Zstd::CDict", + {0, free_cdict, sizeof_cdict,}, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY +}; + +static const rb_data_type_t ddict_type = { + "Zstd::DDict", + {0, free_ddict, sizeof_ddict,}, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY +}; + +static VALUE rb_cdict_alloc(VALUE self) +{ + ZSTD_CDict* cdict = NULL; /* the real dictionary is attached in rb_cdict_initialize */ + return TypedData_Wrap_Struct(self, &cdict_type, cdict); +} + +static VALUE rb_cdict_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE dict; + VALUE compression_level_value; + rb_scan_args(argc, argv, "11", &dict, &compression_level_value); + int compression_level = convert_compression_level(compression_level_value); + + StringValue(dict); + char* dict_buffer = RSTRING_PTR(dict); + size_t dict_size = RSTRING_LEN(dict); + + ZSTD_CDict* const cdict = ZSTD_createCDict(dict_buffer, dict_size, compression_level); + if (cdict == NULL) { + rb_raise(rb_eRuntimeError, "%s", "ZSTD_createCDict failed"); + } + + DATA_PTR(self) = cdict; + return self; +} + +static VALUE rb_ddict_alloc(VALUE self) +{ + ZSTD_DDict* ddict = NULL; /* the real dictionary is attached in rb_ddict_initialize */ + return TypedData_Wrap_Struct(self, &ddict_type, ddict); +} + +static VALUE rb_ddict_initialize(VALUE self, VALUE dict) +{ + StringValue(dict); + char* dict_buffer = RSTRING_PTR(dict); + size_t dict_size = RSTRING_LEN(dict); + + ZSTD_DDict* const ddict = ZSTD_createDDict(dict_buffer, dict_size); + if (ddict == NULL) { + rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDDict failed"); + } + + DATA_PTR(self) = ddict; + return self; +} + +static VALUE rb_prohibit_copy(VALUE self, VALUE obj) +{ + rb_raise(rb_eRuntimeError, "%s cannot be duplicated", rb_obj_classname(self)); /* report the receiver's own class name */ +} + void zstd_ruby_init(void) { @@ -203,4 +287,12 @@ zstd_ruby_init(void) rb_define_module_function(rb_mZstd, "compress_using_dict", rb_compress_using_dict, -1); 
rb_define_module_function(rb_mZstd, "decompress", rb_decompress, -1); rb_define_module_function(rb_mZstd, "decompress_using_dict", rb_decompress_using_dict, -1); + + rb_define_alloc_func(rb_cCDict, rb_cdict_alloc); + rb_define_private_method(rb_cCDict, "initialize", rb_cdict_initialize, -1); + rb_define_method(rb_cCDict, "initialize_copy", rb_prohibit_copy, 1); + + rb_define_alloc_func(rb_cDDict, rb_ddict_alloc); + rb_define_private_method(rb_cDDict, "initialize", rb_ddict_initialize, 1); + rb_define_method(rb_cDDict, "initialize_copy", rb_prohibit_copy, 1); } diff --git a/spec/zstd-ruby-streaming-compress_spec.rb b/spec/zstd-ruby-streaming-compress_spec.rb index 770a2bb..6583b2f 100644 --- a/spec/zstd-ruby-streaming-compress_spec.rb +++ b/spec/zstd-ruby-streaming-compress_spec.rb @@ -53,7 +53,7 @@ end end - describe 'dictionary' do + describe 'String dictionary' do let(:dictionary) do File.read("#{__dir__}/dictionary") end @@ -72,6 +72,25 @@ end end + describe 'Zstd::CDict dictionary' do + let(:cdict) do + Zstd::CDict.new(File.read("#{__dir__}/dictionary"), 5) + end + let(:user_json) do + File.read("#{__dir__}/user_springmt.json") + end + it 'shoud work' do + dict_stream = Zstd::StreamingCompress.new(dict: cdict) + dict_stream << user_json + dict_res = dict_stream.finish + stream = Zstd::StreamingCompress.new(level: 5) + stream << user_json + res = stream.finish + + expect(dict_res.length).to be < res.length + end + end + describe 'nil dictionary' do let(:user_json) do File.read("#{__dir__}/user_springmt.json") diff --git a/spec/zstd-ruby-streaming-decompress_spec.rb b/spec/zstd-ruby-streaming-decompress_spec.rb index d635171..6eb0534 100644 --- a/spec/zstd-ruby-streaming-decompress_spec.rb +++ b/spec/zstd-ruby-streaming-decompress_spec.rb @@ -32,7 +32,7 @@ end end - describe 'dictionary streaming decompress + GC.compact' do + describe 'String dictionary streaming decompress + GC.compact' do let(:dictionary) do File.read("#{__dir__}/dictionary") end @@ -51,6 
+51,28 @@ end end + describe 'Zstd::DDict dictionary streaming decompress + GC.compact' do + let(:dictionary) do + File.read("#{__dir__}/dictionary") + end + let(:ddict) do + Zstd::DDict.new(dictionary) + end + let(:user_json) do + File.read("#{__dir__}/user_springmt.json") + end + it 'shoud work' do + compressed_json = Zstd.compress(user_json, dict: dictionary) + stream = Zstd::StreamingDecompress.new(dict: ddict) + result = '' + result << stream.decompress(compressed_json[0, 5]) + result << stream.decompress(compressed_json[5, 5]) + GC.compact + result << stream.decompress(compressed_json[10..-1]) + expect(result).to eq(user_json) + end + end + describe 'nil dictionary streaming decompress + GC.compact' do let(:dictionary) do File.read("#{__dir__}/dictionary") diff --git a/spec/zstd-ruby-using-dict_spec.rb b/spec/zstd-ruby-using-dict_spec.rb index bfab595..b946567 100644 --- a/spec/zstd-ruby-using-dict_spec.rb +++ b/spec/zstd-ruby-using-dict_spec.rb @@ -7,7 +7,7 @@ # https://github.com/facebook/zstd/releases/tag/v1.1.3 RSpec.describe Zstd do - describe 'compress and decompress with dict keyward args' do + describe 'compress and decompress with String dict keyward args' do let(:user_json) do File.read("#{__dir__}/user_springmt.json") end @@ -52,6 +52,42 @@ end end + describe 'compress and decompress with Zstd::CDict and Zstd::DDict dict keyward args' do + let(:user_json) do + File.read("#{__dir__}/user_springmt.json") + end + let(:cdict) do + Zstd::CDict.new(File.read("#{__dir__}/dictionary")) + end + let(:cdict_10) do + Zstd::CDict.new(File.read("#{__dir__}/dictionary"), 10) + end + let(:ddict) do + Zstd::DDict.new(File.read("#{__dir__}/dictionary")) + end + + it 'should work' do + compressed_using_dict = Zstd.compress(user_json, dict: cdict) + compressed = Zstd.compress(user_json) + expect(compressed_using_dict.length).to be < compressed.length + expect(user_json).to eq(Zstd.decompress(compressed_using_dict, dict: ddict)) + end + + it 'should be able to use 
dictionary multiple times' do + compressed_using_dict = Zstd.compress(user_json, dict: cdict) + expect(compressed_using_dict).to eq(Zstd.compress(user_json, dict: cdict)) + expect(user_json).to eq(Zstd.decompress(compressed_using_dict, dict: ddict)) + expect(user_json).to eq(Zstd.decompress(compressed_using_dict, dict: ddict)) + end + + it 'should support compression levels' do + compressed_using_dict = Zstd.compress(user_json, dict: cdict) + compressed_using_dict_10 = Zstd.compress(user_json, dict: cdict_10) + expect(compressed_using_dict_10.length).to be < compressed_using_dict.length + expect(user_json).to eq(Zstd.decompress(compressed_using_dict_10, dict: ddict)) + end + end + describe 'compress_using_dict' do let(:user_json) do File.read("#{__dir__}/user_springmt.json")