Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure gguf_ctx is closed when error happens #1699

Merged
merged 1 commit into from
Dec 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions mlx/io/gguf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
throw std::invalid_argument("[load_gguf] Failed to open " + file);
}

gguf_ctx* ctx = gguf_open(file.data());
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_open(file.data()), gguf_close);
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}
auto metadata = load_metadata(ctx);
auto arrays = load_arrays(ctx);
gguf_close(ctx);
auto metadata = load_metadata(ctx.get());
auto arrays = load_arrays(ctx.get());
return {arrays, metadata};
}

Expand Down Expand Up @@ -293,7 +293,8 @@ void save_gguf(
file += ".gguf";
}

gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_create(file.c_str(), GGUF_OVERWRITE), gguf_close);
if (!ctx) {
throw std::runtime_error("[save_gguf] gguf_create failed");
}
Expand All @@ -312,7 +313,7 @@ void save_gguf(
std::vector<char> val_vec(size);
string_to_gguf(val_vec.data(), str);
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_STRING,
Expand All @@ -335,7 +336,7 @@ void save_gguf(
str_ptr += str.length() + sizeof(gguf_string);
}
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_ARRAY,
Expand All @@ -361,34 +362,34 @@ void save_gguf(
}
switch (v.dtype()) {
case float32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);
break;
case int64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);
break;
case int32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);
break;
case int16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);
break;
case int8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);
break;
case uint64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);
break;
case uint32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);
break;
case uint16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);
break;
case uint8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);
break;
case bool_:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);
break;
default:
std::ostringstream msg;
Expand Down Expand Up @@ -438,7 +439,7 @@ void save_gguf(
dim[i] = arr.shape()[num_dim - 1 - i];
}
if (!gguf_append_tensor_info(
ctx,
ctx.get(),
tensorname,
namelen,
num_dim,
Expand All @@ -452,11 +453,11 @@ void save_gguf(

// Then, append the tensor weights
for (const auto& [key, arr] : array_map) {
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
if (!gguf_append_tensor_data(
ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
}
}
gguf_close(ctx);
}

} // namespace mlx::core