Skip to content

Commit

Permalink
feat: support CDict & DDict
Browse files Browse the repository at this point in the history
  • Loading branch information
sekiyama58 committed Jan 24, 2025
1 parent 08fdaa9 commit 2c0e8df
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 13 deletions.
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ compressed_data = Zstd.compress(data, level: complession_level) # default compre
compressed_using_dict = Zstd.compress("", dict: File.read('dictionary_file'))
```

#### Compression with CDict

If you use the same dictionary repeatedly, you can speed up the setup by creating a `Zstd::CDict` in advance:

```ruby
cdict = Zstd::CDict.new(File.read('dictionary_file'))
compressed_using_dict = Zstd.compress("", dict: cdict)
```

The compression level can be specified as the second argument when creating the CDict.

```ruby
cdict = Zstd::CDict.new(File.read('dictionary_file'), 5)
compressed_using_dict = Zstd.compress("", dict: cdict)
```

#### Streaming Compression
```ruby
stream = Zstd::StreamingCompress.new
Expand Down Expand Up @@ -86,6 +102,16 @@ stream << "ghi"
res << stream.finish
```

#### Streaming Compression with CDict of level 5
```ruby
cdict = Zstd::CDict.new(File.read('dictionary_file'), 5)
stream = Zstd::StreamingCompress.new(dict: cdict)
stream << "abc" << "def"
res = stream.flush
stream << "ghi"
res << stream.finish
```

### Decompression

#### Simple Decompression
Expand All @@ -100,6 +126,15 @@ data = Zstd.decompress(compressed_data)
Zstd.decompress(compressed_using_dict, dict: File.read('dictionary_file'))
```

#### Decompression with DDict

If you use the same dictionary repeatedly, you can speed up the setup by creating DDict in advance:

```ruby
ddict = Zstd::DDict.new(File.read('dictionary_file'))
data = Zstd.decompress(compressed_using_dict, dict: ddict)
```

#### Streaming Decompression
```ruby
cstr = "" # Compressed data
Expand All @@ -118,6 +153,8 @@ result << stream.decompress(cstr[0, 10])
result << stream.decompress(cstr[10..-1])
```

A `Zstd::DDict` can also be passed as `dict:`.

### Skippable frame

```ruby
Expand Down
46 changes: 36 additions & 10 deletions ext/zstdruby/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <stdbool.h>
#include "./libzstd/zstd.h"

extern VALUE rb_cCDict, rb_cDDict;

static int convert_compression_level(VALUE compression_level_value)
{
if (NIL_P(compression_level_value)) {
Expand All @@ -34,12 +36,24 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL
ZSTD_CCtx_setParameter(ctx, ZSTD_c_compressionLevel, compression_level);

if (kwargs_values[1] != Qundef && kwargs_values[1] != Qnil) {
char* dict_buffer = RSTRING_PTR(kwargs_values[1]);
size_t dict_size = RSTRING_LEN(kwargs_values[1]);
size_t load_dict_ret = ZSTD_CCtx_loadDictionary(ctx, dict_buffer, dict_size);
if (ZSTD_isError(load_dict_ret)) {
if (CLASS_OF(kwargs_values[1]) == rb_cCDict) {
ZSTD_CDict* cdict = DATA_PTR(kwargs_values[1]);
size_t ref_dict_ret = ZSTD_CCtx_refCDict(ctx, cdict);
if (ZSTD_isError(ref_dict_ret)) {
ZSTD_freeCCtx(ctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_refCDict failed");
}
} else if (TYPE(kwargs_values[1]) == T_STRING) {
char* dict_buffer = RSTRING_PTR(kwargs_values[1]);
size_t dict_size = RSTRING_LEN(kwargs_values[1]);
size_t load_dict_ret = ZSTD_CCtx_loadDictionary(ctx, dict_buffer, dict_size);
if (ZSTD_isError(load_dict_ret)) {
ZSTD_freeCCtx(ctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed");
}
} else {
ZSTD_freeCCtx(ctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed");
rb_raise(rb_eArgError, "`dict:` must be a Zstd::CDict or a String");
}
}
}
Expand Down Expand Up @@ -113,12 +127,24 @@ static void set_decompress_params(ZSTD_DCtx* const dctx, VALUE kwargs)
rb_get_kwargs(kwargs, kwargs_keys, 0, 1, kwargs_values);

if (kwargs_values[0] != Qundef && kwargs_values[0] != Qnil) {
char* dict_buffer = RSTRING_PTR(kwargs_values[0]);
size_t dict_size = RSTRING_LEN(kwargs_values[0]);
size_t load_dict_ret = ZSTD_DCtx_loadDictionary(dctx, dict_buffer, dict_size);
if (ZSTD_isError(load_dict_ret)) {
if (CLASS_OF(kwargs_values[0]) == rb_cDDict) {
ZSTD_DDict* ddict = DATA_PTR(kwargs_values[0]);
size_t ref_dict_ret = ZSTD_DCtx_refDDict(dctx, ddict);
if (ZSTD_isError(ref_dict_ret)) {
ZSTD_freeDCtx(dctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_DCtx_refDDict failed");
}
} else if (TYPE(kwargs_values[0]) == T_STRING) {
char* dict_buffer = RSTRING_PTR(kwargs_values[0]);
size_t dict_size = RSTRING_LEN(kwargs_values[0]);
size_t load_dict_ret = ZSTD_DCtx_loadDictionary(dctx, dict_buffer, dict_size);
if (ZSTD_isError(load_dict_ret)) {
ZSTD_freeDCtx(dctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed");
}
} else {
ZSTD_freeDCtx(dctx);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_CCtx_loadDictionary failed");
rb_raise(rb_eArgError, "`dict:` must be a Zstd::DDict or a String");
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions ext/zstdruby/main.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "common.h"

VALUE rb_mZstd;
VALUE rb_cCDict;
VALUE rb_cDDict;
void zstd_ruby_init(void);
void zstd_ruby_skippable_frame_init(void);
void zstd_ruby_streaming_compress_init(void);
Expand All @@ -14,6 +16,8 @@ Init_zstdruby(void)
#endif

rb_mZstd = rb_define_module("Zstd");
rb_cCDict = rb_define_class_under(rb_mZstd, "CDict", rb_cObject);
rb_cDDict = rb_define_class_under(rb_mZstd, "DDict", rb_cObject);
zstd_ruby_init();
zstd_ruby_skippable_frame_init();
zstd_ruby_streaming_compress_init();
Expand Down
92 changes: 92 additions & 0 deletions ext/zstdruby/zstdruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,90 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
return output;
}

static void free_cdict(void *dict)
{
ZSTD_freeCDict(dict);
}

static size_t sizeof_cdict(const void *dict)
{
return ZSTD_sizeof_CDict(dict);
}

static void free_ddict(void *dict)
{
ZSTD_freeDDict(dict);
}

static size_t sizeof_ddict(const void *dict)
{
return ZSTD_sizeof_DDict(dict);
}

static const rb_data_type_t cdict_type = {
"Zstd::CDict",
{0, free_cdict, sizeof_cdict,},
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
};

static const rb_data_type_t ddict_type = {
"Zstd::DDict",
{0, free_ddict, sizeof_ddict,},
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
};

/*
 * Allocator for Zstd::CDict. Wraps a NULL pointer; rb_cdict_initialize
 * stores the real ZSTD_CDict once the dictionary has been built.
 */
static VALUE rb_cdict_alloc(VALUE self)
{
    ZSTD_CDict *empty = NULL;

    return TypedData_Wrap_Struct(self, &cdict_type, empty);
}

/*
 * Zstd::CDict#initialize(dict, compression_level = nil)
 *
 * Builds a ZSTD_CDict from the given dictionary string and stores it in the
 * wrapped data pointer of +self+.
 *
 * dict              - object coerced with StringValue; its bytes are handed
 *                     to ZSTD_createCDict.
 * compression_level - optional; converted by convert_compression_level.
 *
 * Raises RuntimeError if the object already holds a dictionary or if
 * ZSTD_createCDict returns NULL.
 */
static VALUE rb_cdict_initialize(int argc, VALUE *argv, VALUE self)
{
    VALUE dict;
    VALUE compression_level_value;
    int compression_level;
    ZSTD_CDict *cdict;

    /*
     * Calling initialize a second time would overwrite the stored pointer
     * and leak the earlier ZSTD_CDict, so refuse instead of replacing it.
     */
    if (DATA_PTR(self) != NULL) {
        rb_raise(rb_eRuntimeError, "%s", "Zstd::CDict is already initialized");
    }

    rb_scan_args(argc, argv, "11", &dict, &compression_level_value);
    compression_level = convert_compression_level(compression_level_value);

    StringValue(dict);
    cdict = ZSTD_createCDict(RSTRING_PTR(dict), RSTRING_LEN(dict), compression_level);
    if (cdict == NULL) {
        rb_raise(rb_eRuntimeError, "%s", "ZSTD_createCDict failed");
    }

    DATA_PTR(self) = cdict;
    return self;
}

/*
 * Allocator for Zstd::DDict. Wraps a NULL pointer; rb_ddict_initialize
 * stores the real ZSTD_DDict once the dictionary has been built.
 */
static VALUE rb_ddict_alloc(VALUE self)
{
    /* The wrapped object is a decompression dictionary, so use ZSTD_DDict here. */
    ZSTD_DDict* ddict = NULL;
    return TypedData_Wrap_Struct(self, &ddict_type, ddict);
}

/*
 * Zstd::DDict#initialize(dict)
 *
 * Builds a ZSTD_DDict from the given dictionary string and stores it in the
 * wrapped data pointer of +self+.
 *
 * dict - object coerced with StringValue; its bytes are handed to
 *        ZSTD_createDDict.
 *
 * Raises RuntimeError if the object already holds a dictionary or if
 * ZSTD_createDDict returns NULL.
 */
static VALUE rb_ddict_initialize(VALUE self, VALUE dict)
{
    ZSTD_DDict *ddict;

    /*
     * Calling initialize a second time would overwrite the stored pointer
     * and leak the earlier ZSTD_DDict, so refuse instead of replacing it.
     */
    if (DATA_PTR(self) != NULL) {
        rb_raise(rb_eRuntimeError, "%s", "Zstd::DDict is already initialized");
    }

    StringValue(dict);
    ddict = ZSTD_createDDict(RSTRING_PTR(dict), RSTRING_LEN(dict));
    if (ddict == NULL) {
        rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDDict failed");
    }

    DATA_PTR(self) = ddict;
    return self;
}

/*
 * initialize_copy handler shared by Zstd::CDict and Zstd::DDict.
 * Always raises RuntimeError, so dup/clone of a dictionary object fails.
 * The message names the receiver's class because both classes register
 * this function.
 */
static VALUE rb_prohibit_copy(VALUE self, VALUE orig)
{
    /* Parameters are named because unnamed ones are not valid before C23. */
    (void)orig;
    rb_raise(rb_eRuntimeError, "%s cannot be duplicated", rb_obj_classname(self));
}

void
zstd_ruby_init(void)
{
Expand All @@ -203,4 +287,12 @@ zstd_ruby_init(void)
rb_define_module_function(rb_mZstd, "compress_using_dict", rb_compress_using_dict, -1);
rb_define_module_function(rb_mZstd, "decompress", rb_decompress, -1);
rb_define_module_function(rb_mZstd, "decompress_using_dict", rb_decompress_using_dict, -1);

rb_define_alloc_func(rb_cCDict, rb_cdict_alloc);
rb_define_private_method(rb_cCDict, "initialize", rb_cdict_initialize, -1);
rb_define_method(rb_cCDict, "initialize_copy", rb_prohibit_copy, 1);

rb_define_alloc_func(rb_cDDict, rb_ddict_alloc);
rb_define_private_method(rb_cDDict, "initialize", rb_ddict_initialize, 1);
rb_define_method(rb_cDDict, "initialize_copy", rb_prohibit_copy, 1);
}
21 changes: 20 additions & 1 deletion spec/zstd-ruby-streaming-compress_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
end
end

describe 'dictionary' do
describe 'String dictionary' do
let(:dictionary) do
File.read("#{__dir__}/dictionary")
end
Expand All @@ -72,6 +72,25 @@
end
end

describe 'Zstd::CDict dictionary' do
  # CDict built at level 5 from the dictionary fixture next to this spec.
  let(:cdict) { Zstd::CDict.new(File.read("#{__dir__}/dictionary"), 5) }
  let(:user_json) { File.read("#{__dir__}/user_springmt.json") }

  # Streaming with the CDict should produce a smaller result than streaming
  # at the same level without any dictionary.
  it 'shoud work' do
    with_dict = Zstd::StreamingCompress.new(dict: cdict)
    with_dict << user_json
    dict_output = with_dict.finish

    without_dict = Zstd::StreamingCompress.new(level: 5)
    without_dict << user_json
    plain_output = without_dict.finish

    expect(dict_output.length).to be < plain_output.length
  end
end

describe 'nil dictionary' do
let(:user_json) do
File.read("#{__dir__}/user_springmt.json")
Expand Down
24 changes: 23 additions & 1 deletion spec/zstd-ruby-streaming-decompress_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
end
end

describe 'dictionary streaming decompress + GC.compact' do
describe 'String dictionary streaming decompress + GC.compact' do
let(:dictionary) do
File.read("#{__dir__}/dictionary")
end
Expand All @@ -51,6 +51,28 @@
end
end

describe 'Zstd::DDict dictionary streaming decompress + GC.compact' do
  let(:dictionary) { File.read("#{__dir__}/dictionary") }
  let(:ddict) { Zstd::DDict.new(dictionary) }
  let(:user_json) { File.read("#{__dir__}/user_springmt.json") }

  # Compress with the raw String dictionary, then decompress in three chunks
  # through a DDict-backed stream, compacting the heap before the last chunk.
  it 'shoud work' do
    compressed_json = Zstd.compress(user_json, dict: dictionary)
    decompressor = Zstd::StreamingDecompress.new(dict: ddict)

    result = ''
    head = compressed_json[0, 5]
    middle = compressed_json[5, 5]
    result << decompressor.decompress(head)
    result << decompressor.decompress(middle)
    GC.compact
    tail = compressed_json[10..-1]
    result << decompressor.decompress(tail)

    expect(result).to eq(user_json)
  end
end

describe 'nil dictionary streaming decompress + GC.compact' do
let(:dictionary) do
File.read("#{__dir__}/dictionary")
Expand Down
38 changes: 37 additions & 1 deletion spec/zstd-ruby-using-dict_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# https://github.com/facebook/zstd/releases/tag/v1.1.3

RSpec.describe Zstd do
describe 'compress and decompress with dict keyward args' do
describe 'compress and decompress with String dict keyward args' do
let(:user_json) do
File.read("#{__dir__}/user_springmt.json")
end
Expand Down Expand Up @@ -52,6 +52,42 @@
end
end

describe 'compress and decompress with Zstd::CDict and Zstd::DDict dict keyward args' do
  let(:user_json) { File.read("#{__dir__}/user_springmt.json") }
  # CDict created without an explicit level.
  let(:cdict) { Zstd::CDict.new(File.read("#{__dir__}/dictionary")) }
  # CDict created with compression level 10.
  let(:cdict_10) { Zstd::CDict.new(File.read("#{__dir__}/dictionary"), 10) }
  let(:ddict) { Zstd::DDict.new(File.read("#{__dir__}/dictionary")) }

  it 'should work' do
    with_dict = Zstd.compress(user_json, dict: cdict)
    without_dict = Zstd.compress(user_json)

    expect(with_dict.length).to be < without_dict.length
    expect(user_json).to eq(Zstd.decompress(with_dict, dict: ddict))
  end

  it 'should be able to use dictionary multiple times' do
    first_pass = Zstd.compress(user_json, dict: cdict)

    expect(first_pass).to eq(Zstd.compress(user_json, dict: cdict))
    # Decompress twice with the same DDict to check that it is reusable.
    expect(user_json).to eq(Zstd.decompress(first_pass, dict: ddict))
    expect(user_json).to eq(Zstd.decompress(first_pass, dict: ddict))
  end

  it 'should support compression levels' do
    default_level = Zstd.compress(user_json, dict: cdict)
    level_10 = Zstd.compress(user_json, dict: cdict_10)

    expect(level_10.length).to be < default_level.length
    expect(user_json).to eq(Zstd.decompress(level_10, dict: ddict))
  end
end

describe 'compress_using_dict' do
let(:user_json) do
File.read("#{__dir__}/user_springmt.json")
Expand Down

0 comments on commit 2c0e8df

Please sign in to comment.