Skip to content

Commit

Permalink
Make sure gguf_ctx is closed when error happens (#1699)
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz authored Dec 14, 2024
1 parent dfccd17 commit 4768c61
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions mlx/io/gguf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
throw std::invalid_argument("[load_gguf] Failed to open " + file);
}

gguf_ctx* ctx = gguf_open(file.data());
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_open(file.data()), gguf_close);
if (!ctx) {
throw std::runtime_error("[load_gguf] gguf_init failed");
}
auto metadata = load_metadata(ctx);
auto arrays = load_arrays(ctx);
gguf_close(ctx);
auto metadata = load_metadata(ctx.get());
auto arrays = load_arrays(ctx.get());
return {arrays, metadata};
}

Expand Down Expand Up @@ -293,7 +293,8 @@ void save_gguf(
file += ".gguf";
}

gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
gguf_create(file.c_str(), GGUF_OVERWRITE), gguf_close);
if (!ctx) {
throw std::runtime_error("[save_gguf] gguf_create failed");
}
Expand All @@ -312,7 +313,7 @@ void save_gguf(
std::vector<char> val_vec(size);
string_to_gguf(val_vec.data(), str);
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_STRING,
Expand All @@ -335,7 +336,7 @@ void save_gguf(
str_ptr += str.length() + sizeof(gguf_string);
}
gguf_append_kv(
ctx,
ctx.get(),
key.c_str(),
key.length(),
GGUF_VALUE_TYPE_ARRAY,
Expand All @@ -361,34 +362,34 @@ void save_gguf(
}
switch (v.dtype()) {
case float32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);
break;
case int64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);
break;
case int32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);
break;
case int16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);
break;
case int8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);
break;
case uint64:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);
break;
case uint32:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);
break;
case uint16:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);
break;
case uint8:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);
break;
case bool_:
append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);
break;
default:
std::ostringstream msg;
Expand Down Expand Up @@ -438,7 +439,7 @@ void save_gguf(
dim[i] = arr.shape()[num_dim - 1 - i];
}
if (!gguf_append_tensor_info(
ctx,
ctx.get(),
tensorname,
namelen,
num_dim,
Expand All @@ -452,11 +453,11 @@ void save_gguf(

// Then, append the tensor weights
for (const auto& [key, arr] : array_map) {
if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
if (!gguf_append_tensor_data(
ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {
throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
}
}
gguf_close(ctx);
}

} // namespace mlx::core

0 comments on commit 4768c61

Please sign in to comment.