diff --git a/mlx/io/gguf.cpp b/mlx/io/gguf.cpp
index dffb2aa1d7..c1a1d03bf9 100644
--- a/mlx/io/gguf.cpp
+++ b/mlx/io/gguf.cpp
@@ -241,13 +241,13 @@ GGUFLoad load_gguf(const std::string& file, StreamOrDevice s) {
     throw std::invalid_argument("[load_gguf] Failed to open " + file);
   }
 
-  gguf_ctx* ctx = gguf_open(file.data());
+  std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
+      gguf_open(file.data()), gguf_close);
   if (!ctx) {
     throw std::runtime_error("[load_gguf] gguf_init failed");
   }
-  auto metadata = load_metadata(ctx);
-  auto arrays = load_arrays(ctx);
-  gguf_close(ctx);
+  auto metadata = load_metadata(ctx.get());
+  auto arrays = load_arrays(ctx.get());
   return {arrays, metadata};
 }
 
@@ -293,7 +293,8 @@ void save_gguf(
     file += ".gguf";
   }
 
-  gguf_ctx* ctx = gguf_create(file.c_str(), GGUF_OVERWRITE);
+  std::unique_ptr<gguf_ctx, decltype(&gguf_close)> ctx(
+      gguf_create(file.c_str(), GGUF_OVERWRITE), gguf_close);
   if (!ctx) {
     throw std::runtime_error("[save_gguf] gguf_create failed");
   }
@@ -312,7 +313,7 @@ void save_gguf(
       std::vector<char> val_vec(size);
       string_to_gguf(val_vec.data(), str);
       gguf_append_kv(
-          ctx,
+          ctx.get(),
           key.c_str(),
           key.length(),
           GGUF_VALUE_TYPE_STRING,
@@ -335,7 +336,7 @@ void save_gguf(
         str_ptr += str.length() + sizeof(gguf_string);
       }
       gguf_append_kv(
-          ctx,
+          ctx.get(),
           key.c_str(),
           key.length(),
           GGUF_VALUE_TYPE_ARRAY,
@@ -361,34 +362,34 @@ void save_gguf(
     }
     switch (v.dtype()) {
       case float32:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_FLOAT32);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_FLOAT32);
         break;
       case int64:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT64);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT64);
         break;
       case int32:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT32);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT32);
         break;
       case int16:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT16);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT16);
         break;
       case int8:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_INT8);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_INT8);
         break;
       case uint64:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT64);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT64);
         break;
       case uint32:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT32);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT32);
        break;
       case uint16:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT16);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT16);
         break;
       case uint8:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_UINT8);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_UINT8);
         break;
       case bool_:
-        append_kv_array(ctx, key, v, GGUF_VALUE_TYPE_BOOL);
+        append_kv_array(ctx.get(), key, v, GGUF_VALUE_TYPE_BOOL);
         break;
       default:
         std::ostringstream msg;
@@ -438,7 +439,7 @@ void save_gguf(
       dim[i] = arr.shape()[num_dim - 1 - i];
     }
     if (!gguf_append_tensor_info(
-            ctx,
+            ctx.get(),
             tensorname,
             namelen,
             num_dim,
@@ -452,11 +453,11 @@
 
   // Then, append the tensor weights
   for (const auto& [key, arr] : array_map) {
-    if (!gguf_append_tensor_data(ctx, (void*)arr.data<void>(), arr.nbytes())) {
+    if (!gguf_append_tensor_data(
+            ctx.get(), (void*)arr.data<void>(), arr.nbytes())) {
       throw std::runtime_error("[save_gguf] gguf_append_tensor_data failed");
     }
   }
-  gguf_close(ctx);
 }
 
 } // namespace mlx::core
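
Note for reviewers: the change wraps the raw gguf_ctx* handle in a std::unique_ptr with gguf_close as a custom deleter, so the context is released on every exit path, including the throw branches that previously skipped the manual gguf_close call. The following is a minimal standalone sketch of the same pattern; the widget/widget_open/widget_close names are hypothetical stand-ins for the gguflib API, not part of this change.

#include <cstdio>
#include <memory>
#include <stdexcept>

// Hypothetical C-style resource API, standing in for gguf_open/gguf_close.
struct widget {
  int id;
};
widget* widget_open(const char* /*name*/) { return new widget{42}; }
void widget_close(widget* w) { delete w; }

void use_widget(const char* name) {
  // unique_ptr with a custom deleter: widget_close runs automatically when
  // ctx goes out of scope, whether the function returns or throws.
  std::unique_ptr<widget, decltype(&widget_close)> ctx(
      widget_open(name), widget_close);
  if (!ctx) {
    throw std::runtime_error("widget_open failed");
  }
  // C functions that still expect a raw widget* receive ctx.get().
  std::printf("widget id: %d\n", ctx->id);
}  // widget_close(ctx.get()) is invoked here (only if ctx is non-null).

int main() {
  use_widget("example");
}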