Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Use mutex from utils #1851

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/data/cassandra/impl/AsyncExecutor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "data/cassandra/Handle.hpp"
#include "data/cassandra/Types.hpp"
#include "data/cassandra/impl/RetryPolicy.hpp"
#include "util/Mutex.hpp"
#include "util/log/Logger.hpp"

#include <boost/asio.hpp>
Expand Down Expand Up @@ -64,8 +65,8 @@ class AsyncExecutor : public std::enable_shared_from_this<AsyncExecutor<Statemen
RetryCallbackType onRetry_;

// does not exist during initial construction, hence optional
std::optional<FutureWithCallbackType> future_;
std::mutex mtx_;
using OptionalFuture = std::optional<FutureWithCallbackType>;
util::Mutex<OptionalFuture> future_;

public:
/**
Expand Down Expand Up @@ -127,8 +128,8 @@ class AsyncExecutor : public std::enable_shared_from_this<AsyncExecutor<Statemen
self = nullptr; // explicitly decrement refcount
};

std::scoped_lock const lck{mtx_};
future_.emplace(handle.asyncExecute(data_, std::move(handler)));
auto future = future_.template lock<std::scoped_lock>();
future->emplace(handle.asyncExecute(data_, std::move(handler)));
}
};

Expand Down
22 changes: 11 additions & 11 deletions src/feed/impl/TrackableSignal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

#pragma once

#include "util/Mutex.hpp"

#include <boost/signals2.hpp>
#include <boost/signals2/connection.hpp>
#include <boost/signals2/variadic_signal.hpp>
Expand All @@ -45,8 +47,8 @@ class TrackableSignal {

// map of connection and signal connection, key is the pointer of the connection object
// allow disconnect to be called in the destructor of the connection
std::unordered_map<ConnectionPtr, boost::signals2::connection> connections_;
mutable std::mutex mutex_;
using ConnectionsMap = std::unordered_map<ConnectionPtr, boost::signals2::connection>;
util::Mutex<ConnectionsMap> connections_;

using SignalType = boost::signals2::signal<void(Args...)>;
SignalType signal_;
Expand All @@ -64,16 +66,16 @@ class TrackableSignal {
bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, std::function<void(Args...)> slot)
{
std::scoped_lock const lk(mutex_);
if (connections_.contains(trackable.get())) {
auto connections = connections_.template lock<std::scoped_lock>();
if (connections->contains(trackable.get())) {
return false;
}

// This class can't hold the trackable's shared_ptr, because disconnect should be able to be called in the
// the trackable's destructor. However, the trackable can not be destroied when the slot is being called
// either. track_foreign will hold a weak_ptr to the connection, which makes sure the connection is valid when
// the slot is called.
connections_.emplace(
connections->emplace(
trackable.get(), signal_.connect(typename SignalType::slot_type(slot).track_foreign(trackable))
);
return true;
Expand All @@ -89,10 +91,9 @@ class TrackableSignal {
bool
disconnect(ConnectionPtr trackablePtr)
{
std::scoped_lock const lk(mutex_);
if (connections_.contains(trackablePtr)) {
connections_[trackablePtr].disconnect();
connections_.erase(trackablePtr);
if (auto connections = connections_.template lock<std::scoped_lock>(); connections->contains(trackablePtr)) {
connections->operator[](trackablePtr).disconnect();
connections->erase(trackablePtr);
return true;
}
return false;
Expand All @@ -115,8 +116,7 @@ class TrackableSignal {
std::size_t
count() const
{
std::scoped_lock const lk(mutex_);
return connections_.size();
return connections_.template lock<std::scoped_lock>()->size();
}
};
} // namespace feed::impl
25 changes: 13 additions & 12 deletions src/feed/impl/TrackableSignalMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#pragma once

#include "feed/impl/TrackableSignal.hpp"
#include "util/Mutex.hpp"

#include <boost/signals2.hpp>

Expand Down Expand Up @@ -49,8 +50,8 @@ class TrackableSignalMap {
using ConnectionPtr = Session*;
using ConnectionSharedPtr = std::shared_ptr<Session>;

mutable std::mutex mutex_;
std::unordered_map<Key, TrackableSignal<Session, Args...>> signalsMap_;
using SignalsMap = std::unordered_map<Key, TrackableSignal<Session, Args...>>;
util::Mutex<SignalsMap> signalsMap_;

public:
/**
Expand All @@ -66,8 +67,8 @@ class TrackableSignalMap {
bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, Key const& key, std::function<void(Args...)> slot)
{
std::scoped_lock const lk(mutex_);
return signalsMap_[key].connectTrackableSlot(trackable, slot);
auto map = signalsMap_.template lock<std::scoped_lock>();
return map->operator[](key).connectTrackableSlot(trackable, slot);
}

/**
Expand All @@ -80,14 +81,14 @@ class TrackableSignalMap {
bool
disconnect(ConnectionPtr trackablePtr, Key const& key)
{
std::scoped_lock const lk(mutex_);
if (!signalsMap_.contains(key))
auto map = signalsMap_.template lock<std::scoped_lock>();
if (!map->contains(key))
return false;

auto const disconnected = signalsMap_[key].disconnect(trackablePtr);
auto const disconnected = map->operator[](key).disconnect(trackablePtr);
// clean the map if there is no connection left.
if (disconnected && signalsMap_[key].count() == 0)
signalsMap_.erase(key);
if (disconnected && map->operator[](key).count() == 0)
map->erase(key);

return disconnected;
}
Expand All @@ -101,9 +102,9 @@ class TrackableSignalMap {
void
emit(Key const& key, Args const&... args)
{
std::scoped_lock const lk(mutex_);
if (signalsMap_.contains(key))
signalsMap_[key].emit(args...);
auto map = signalsMap_.template lock<std::scoped_lock>();
if (map->contains(key))
map->operator[](key).emit(args...);
}
};
} // namespace feed::impl
6 changes: 3 additions & 3 deletions src/util/Mutex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Mutex;
* @tparam LockType type of lock
* @tparam MutexType type of mutex
*/
template <typename ProtectedDataType, template <typename> typename LockType, typename MutexType>
template <typename ProtectedDataType, template <typename...> typename LockType, typename MutexType>
class Lock {
LockType<MutexType> lock_;
ProtectedDataType& data_;
Expand Down Expand Up @@ -129,7 +129,7 @@ class Mutex {
* @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data
*/
template <template <typename> typename LockType = std::lock_guard>
template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType const, LockType, MutexType>
lock() const
{
Expand All @@ -142,7 +142,7 @@ class Mutex {
* @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data
*/
template <template <typename> typename LockType = std::lock_guard>
template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType, LockType, MutexType>
lock()
{
Expand Down
45 changes: 25 additions & 20 deletions src/util/prometheus/impl/HistogramImpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "util/Assert.hpp"
#include "util/Concepts.hpp"
#include "util/Mutex.hpp"
#include "util/prometheus/OStream.hpp"

#include <cstdint>
Expand Down Expand Up @@ -61,28 +62,30 @@ class HistogramImpl {
void
setBuckets(std::vector<ValueType> const& bounds)
{
std::scoped_lock const lock{*mutex_};
ASSERT(buckets_.empty(), "Buckets can be set only once.");
buckets_.reserve(bounds.size());
auto data = data_->template lock<std::scoped_lock>();
ASSERT(data->buckets.empty(), "Buckets can be set only once.");
data->buckets.reserve(bounds.size());
for (auto const& bound : bounds) {
buckets_.emplace_back(bound);
data->buckets.emplace_back(bound);
}
}

void
observe(ValueType const value)
{
auto const bucket =
std::lower_bound(buckets_.begin(), buckets_.end(), value, [](Bucket const& bucket, ValueType const& value) {
return bucket.upperBound < value;
});
std::scoped_lock const lock{*mutex_};
if (bucket != buckets_.end()) {
auto data = data_->template lock<std::scoped_lock>();
auto const bucket = std::lower_bound(
data->buckets.begin(),
data->buckets.end(),
value,
[](Bucket const& bucket, ValueType const& value) { return bucket.upperBound < value; }
);
if (bucket != data->buckets.end()) {
++bucket->count;
} else {
++lastBucket_.count;
++data->lastBucket.count;
}
sum_ += value;
data->sum += value;
}

void
Expand All @@ -98,23 +101,23 @@ class HistogramImpl {
labelsString.back() = ',';
}

std::scoped_lock const lock{*mutex_};
auto data = data_->template lock<std::scoped_lock>();
std::uint64_t cumulativeCount = 0;

for (auto const& bucket : buckets_) {
for (auto const& bucket : data->buckets) {
cumulativeCount += bucket.count;
stream << name << "_bucket" << labelsString << "le=\"" << bucket.upperBound << "\"} " << cumulativeCount
<< '\n';
}
cumulativeCount += lastBucket_.count;
cumulativeCount += data->lastBucket.count;
stream << name << "_bucket" << labelsString << "le=\"+Inf\"} " << cumulativeCount << '\n';

if (labelsString.size() == 1) {
labelsString = "";
} else {
labelsString.back() = '}';
}
stream << name << "_sum" << labelsString << " " << sum_ << '\n';
stream << name << "_sum" << labelsString << " " << data->sum << '\n';
stream << name << "_count" << labelsString << " " << cumulativeCount << '\n';
}

Expand All @@ -128,10 +131,12 @@ class HistogramImpl {
std::uint64_t count = 0;
};

std::vector<Bucket> buckets_;
Bucket lastBucket_{std::numeric_limits<ValueType>::max()};
ValueType sum_ = 0;
mutable std::unique_ptr<std::mutex> mutex_ = std::make_unique<std::mutex>();
struct Data {
std::vector<Bucket> buckets;
Bucket lastBucket{std::numeric_limits<ValueType>::max()};
ValueType sum = 0;
};
std::unique_ptr<util::Mutex<Data>> data_ = std::make_unique<util::Mutex<Data>>();
};

} // namespace util::prometheus::impl
40 changes: 20 additions & 20 deletions src/web/dosguard/DOSGuard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ DOSGuard::isOk(std::string const& ip) const noexcept
return true;

{
std::scoped_lock const lck(mtx_);
if (ipState_.find(ip) != ipState_.end()) {
auto [transferedByte, requests] = ipState_.at(ip);
if (transferedByte > maxFetches_ || requests > maxRequestCount_) {
auto lock = mtx_.lock<std::scoped_lock>();
if (lock->ipState.find(ip) != lock->ipState.end()) {
auto [transferredByte, requests] = lock->ipState.at(ip);
if (transferredByte > maxFetches_ || requests > maxRequestCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Transfered Byte: " << transferedByte << "; Requests: " << requests;
<< " Transfered Byte: " << transferredByte << "; Requests: " << requests;
return false;
}
}
auto it = ipConnCount_.find(ip);
if (it != ipConnCount_.end()) {
auto it = lock->ipConnCount.find(ip);
if (it != lock->ipConnCount.end()) {
if (it->second > maxConnCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Concurrent connection: " << it->second;
Expand All @@ -84,20 +84,20 @@ DOSGuard::increment(std::string const& ip) noexcept
{
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock const lck{mtx_};
ipConnCount_[ip]++;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipConnCount[ip]++;
}

void
DOSGuard::decrement(std::string const& ip) noexcept
{
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock const lck{mtx_};
ASSERT(ipConnCount_[ip] > 0, "Connection count for ip {} can't be 0", ip);
ipConnCount_[ip]--;
if (ipConnCount_[ip] == 0)
ipConnCount_.erase(ip);
auto lock = mtx_.lock<std::scoped_lock>();
ASSERT(lock->ipConnCount[ip] > 0, "Connection count for ip {} can't be 0", ip);
lock->ipConnCount[ip]--;
if (lock->ipConnCount[ip] == 0)
lock->ipConnCount.erase(ip);
}

[[maybe_unused]] bool
Expand All @@ -107,8 +107,8 @@ DOSGuard::add(std::string const& ip, uint32_t numObjects) noexcept
return true;

{
std::scoped_lock const lck(mtx_);
ipState_[ip].transferedByte += numObjects;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState[ip].transferedByte += numObjects;
}

return isOk(ip);
Expand All @@ -121,8 +121,8 @@ DOSGuard::request(std::string const& ip) noexcept
return true;

{
std::scoped_lock const lck(mtx_);
ipState_[ip].requestsCount++;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState[ip].requestsCount++;
}

return isOk(ip);
Expand All @@ -131,8 +131,8 @@ DOSGuard::request(std::string const& ip) noexcept
void
DOSGuard::clear() noexcept
{
std::scoped_lock const lck(mtx_);
ipState_.clear();
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState.clear();
}

[[nodiscard]] std::unordered_set<std::string>
Expand Down
11 changes: 7 additions & 4 deletions src/web/dosguard/DOSGuard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#pragma once

#include "util/Mutex.hpp"
#include "util/log/Logger.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
Expand All @@ -30,7 +31,6 @@

#include <cstdint>
#include <functional>
#include <mutex>
#include <string>
#include <string_view>
#include <unordered_map>
Expand All @@ -52,9 +52,12 @@ class DOSGuard : public DOSGuardInterface {
std::uint32_t requestsCount = 0; /**< Accumulated served requests count */
};

mutable std::mutex mtx_;
std::unordered_map<std::string, ClientState> ipState_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
struct State {
std::unordered_map<std::string, ClientState> ipState;
std::unordered_map<std::string, std::uint32_t> ipConnCount;
};
util::Mutex<State> mtx_;

std::reference_wrapper<WhitelistHandlerInterface const> whitelistHandler_;

std::uint32_t const maxFetches_;
Expand Down
Loading
Loading