Skip to content

Commit

Permalink
Add RAII file wrappers to avoid resource leak (#614)
Browse files Browse the repository at this point in the history
Closes #607

Authors:
  - Tianyu Liu (https://github.com/kingcrimsontianyu)

Approvers:
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Vukasin Milovanovic (https://github.com/vuule)

URL: #614
  • Loading branch information
kingcrimsontianyu authored Feb 4, 2025
1 parent 3630b8c commit 255ed48
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 139 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ set(SOURCES
"src/defaults.cpp"
"src/error.cpp"
"src/file_handle.cpp"
"src/file_utils.cpp"
"src/posix_io.cpp"
"src/shim/cuda.cpp"
"src/shim/cufile.cpp"
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/kvikio/error.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ void cufile_check_bytes_done_2(ssize_t nbytes_done, int line_number, char const*
}
}

#define KVIKIO_LOG_ERROR(err_msg) kvikio::detail::log_error(err_msg, __LINE__, __FILE__)
void log_error(std::string_view err_msg, int line_number, char const* filename);

} // namespace detail

} // namespace kvikio
20 changes: 11 additions & 9 deletions cpp/include/kvikio/file_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <kvikio/cufile/config.hpp>
#include <kvikio/defaults.hpp>
#include <kvikio/error.hpp>
#include <kvikio/file_utils.hpp>
#include <kvikio/parallel_operation.hpp>
#include <kvikio/posix_io.hpp>
#include <kvikio/shim/cufile.hpp>
Expand All @@ -44,12 +45,12 @@ namespace kvikio {
class FileHandle {
private:
// We use two file descriptors, one opened with the O_DIRECT flag and one without.
int _fd_direct_on{-1};
int _fd_direct_off{-1};
FileWrapper _fd_direct_on{};
FileWrapper _fd_direct_off{};
bool _initialized{false};
CompatMode _compat_mode{CompatMode::AUTO};
mutable std::size_t _nbytes{0}; // The size of the underlying file, zero means unknown.
CUfileHandle_t _handle{};
CUFileHandleWrapper _cufile_handle{};

/**
* @brief Given a requested compatibility mode, whether it is expected to reduce to `ON` for
Expand Down Expand Up @@ -122,23 +123,24 @@ class FileHandle {
* @brief Get one of the file descriptors
*
* Notice, FileHandle maintains two file descriptors - one opened with the
* `O_DIRECT` flag and one without. This function returns one of them but
* it is unspecified which one.
* `O_DIRECT` flag and one without.
*
* @param o_direct Whether to get the file descriptor opened with the `O_DIRECT` flag.
* @return File descriptor
*/
[[nodiscard]] int fd() const noexcept;
[[nodiscard]] int fd(bool o_direct = false) const noexcept;

/**
* @brief Get the flags of one of the file descriptors (see open(2))
*
* Notice, FileHandle maintains two file descriptors - one opened with the
* `O_DIRECT` flag and one without. This function returns the flags of one of
* them but it is unspecified which one.
* `O_DIRECT` flag and one without.
*
* @param o_direct Whether to get the flags of the file descriptor opened with the `O_DIRECT`
* flag.
* @return File descriptor
*/
[[nodiscard]] int fd_open_flags() const;
[[nodiscard]] int fd_open_flags(bool o_direct = false) const;

/**
* @brief Get the file size
Expand Down
165 changes: 165 additions & 0 deletions cpp/include/kvikio/file_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <optional>
#include <string>

#include <kvikio/shim/cufile_h_wrapper.hpp>

namespace kvikio {
/**
* @brief Class that provides RAII for file handling.
*/
class FileWrapper {
private:
int _fd{-1};

public:
/**
* @brief Open file.
*
* @param file_path File path.
* @param flags Open flags given as a string.
* @param o_direct Append O_DIRECT to `flags`.
* @param mode Access modes.
*/
FileWrapper(std::string const& file_path, std::string const& flags, bool o_direct, mode_t mode);

/**
* @brief Construct an empty file wrapper object without opening a file.
*/
FileWrapper() noexcept = default;

~FileWrapper() noexcept;
FileWrapper(FileWrapper const&) = delete;
FileWrapper& operator=(FileWrapper const&) = delete;
FileWrapper(FileWrapper&& o) noexcept;
FileWrapper& operator=(FileWrapper&& o) noexcept;

/**
* @brief Open file using `open(2)`
*
* @param file_path File path.
* @param flags Open flags given as a string.
* @param o_direct Append O_DIRECT to `flags`.
* @param mode Access modes.
*/
void open(std::string const& file_path, std::string const& flags, bool o_direct, mode_t mode);

/**
* @brief Check if the file has been opened.
*
* @return A boolean answer indicating if the file has been opened.
*/
bool opened() const noexcept;

/**
* @brief Close the file if it is opened; do nothing otherwise.
*/
void close() noexcept;

/**
* @brief Return the file descriptor.
*
* @return File descriptor.
*/
int fd() const noexcept;
};

/**
* @brief Class that provides RAII for the cuFile handle.
*/
class CUFileHandleWrapper {
private:
CUfileHandle_t _handle{};
bool _registered{false};

public:
CUFileHandleWrapper() noexcept = default;
~CUFileHandleWrapper() noexcept;
CUFileHandleWrapper(CUFileHandleWrapper const&) = delete;
CUFileHandleWrapper& operator=(CUFileHandleWrapper const&) = delete;
CUFileHandleWrapper(CUFileHandleWrapper&& o) noexcept;
CUFileHandleWrapper& operator=(CUFileHandleWrapper&& o) noexcept;

/**
* @brief Register the file handle given the file descriptor.
*
* @param fd File descriptor.
* @return Return the cuFile error code from handle register. If the handle has already been
* registered by calling `register_handle()`, return `std::nullopt`.
*/
std::optional<CUfileError_t> register_handle(int fd) noexcept;

/**
* @brief Check if the handle has been registered.
*
* @return A boolean answer indicating if the handle has been registered.
*/
bool registered() const noexcept;

/**
* @brief Return the cuFile handle.
*
* @return The cuFile handle.
*/
CUfileHandle_t handle() const noexcept;

/**
* @brief Unregister the handle if it has been registered; do nothing otherwise.
*/
void unregister_handle() noexcept;
};

/**
* @brief Parse open file flags given as a string and return oflags
*
* @param flags The flags
* @param o_direct Append O_DIRECT to the open flags
* @return oflags
*
* @throw std::invalid_argument if the specified flags are not supported.
* @throw std::invalid_argument if `o_direct` is true, but `O_DIRECT` is not supported.
*/
int open_fd_parse_flags(std::string const& flags, bool o_direct);

/**
* @brief Open file using `open(2)`
*
* @param flags Open flags given as a string
* @param o_direct Append O_DIRECT to `flags`
* @param mode Access modes
* @return File descriptor
*/
int open_fd(std::string const& file_path, std::string const& flags, bool o_direct, mode_t mode);

/**
* @brief Get the flags of the file descriptor (see `open(2)`)
*
* @return Open flags
*/
[[nodiscard]] int open_flags(int fd);

/**
* @brief Get file size from file descriptor `fstat(3)`
*
* @param file_descriptor Open file descriptor
* @return The number of bytes
*/
[[nodiscard]] std::size_t get_file_size(int file_descriptor);

} // namespace kvikio
15 changes: 15 additions & 0 deletions cpp/src/error.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,19 @@
* limitations under the License.
*/

#include <iostream>

#include <kvikio/error.hpp>

namespace kvikio {

namespace detail {

void log_error(std::string_view err_msg, int line_number, char const* filename)
{
std::cerr << "KvikIO error at: " << filename << ":" << line_number << ": " << err_msg << "\n";
}

} // namespace detail

} // namespace kvikio
Loading

0 comments on commit 255ed48

Please sign in to comment.