Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend websocket server to support message limits and throttling #116

Merged
merged 4 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 205 additions & 20 deletions plugins/websocket_server/WebsocketServer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ int rootCallback(struct lws *_wsi,
case LWS_CALLBACK_SERVER_WRITEABLE:
{
std::lock_guard<std::mutex> lock(self->connections[fd]->mutex);

if (!self->connections[fd]->buffer.empty())
{
int msgSize = self->connections[fd]->len.front();
Expand Down Expand Up @@ -360,7 +361,6 @@ bool WebsocketServer::Load(const tinyxml2::XMLElement *_elem)
}
this->publishPeriod = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(1.0 / hz));

// Get the authorization key, if present.
elem = _elem->FirstChildElement("authorization_key");
if (elem)
Expand Down Expand Up @@ -406,13 +406,62 @@ bool WebsocketServer::Load(const tinyxml2::XMLElement *_elem)
}
catch (...)
{
ignerr << "Failed to convert port[" << elem->GetText() << "] to integer."
<< std::endl;
ignerr << "Failed to convert max_connections[" << elem->GetText()
<< "] to integer." << std::endl;
}
igndbg << "Using maximum connection count of "
<< this->maxConnections << std::endl;
}

// Get the msg count per connection.
elem = _elem->FirstChildElement("queue_size_per_connection");
if (elem)
{
int size = -1;
auto result = elem->QueryIntText(&size);
if (result == tinyxml2::XML_SUCCESS && size >= 0)
{
this->queueSizePerConnection = size;
}
else
{
ignerr << "Failed to parse queue_size_per_connection["
<< elem->GetText() << "]." << std::endl;
}
igndbg << "Using connection msg queue size of "
<< this->queueSizePerConnection << std::endl;
}

// Get the msg type subscription limit
elem = _elem->FirstChildElement("subscription_limit_per_connection");
if (elem)
{
auto childElem = elem->FirstChildElement("subscription");
while (childElem)
{
auto msgTypeElem = childElem->FirstChildElement("msg_type");
auto limitElem = childElem->FirstChildElement("limit");
if (msgTypeElem && limitElem)
{
std::string msgType = msgTypeElem->GetText();
int limit = -1;
auto result = limitElem->QueryIntText(&limit);
if (result == tinyxml2::XML_SUCCESS && limit >= 0)
{
this->msgTypeSubscriptionLimit[msgType] = limit;
igndbg << "Setting msg type subscription limit[" << msgType
<< ", " << limit << "]" << std::endl;
}
else
{
ignerr << "Failed to parse subscription limit["
<< msgType << ", " << limitElem->GetText() << "]." << std::endl;
}
}
childElem = childElem->NextSiblingElement("subscription");
}
}

std::string sslCertFile = "";
std::string sslPrivateKeyFile = "";
elem = _elem->FirstChildElement("ssl");
Expand Down Expand Up @@ -527,12 +576,23 @@ void WebsocketServer::QueueMessage(Connection *_connection,
memcpy(buf.get() + LWS_PRE, _data, _size);

std::lock_guard<std::mutex> lock(_connection->mutex);
_connection->buffer.push_back(std::move(buf));
_connection->len.push_back(_size);
if (_connection->buffer.size() < this->queueSizePerConnection)
{
_connection->buffer.push_back(std::move(buf));
_connection->len.push_back(_size);

std::scoped_lock<std::mutex> runLock(this->runMutex);
this->messageCount++;
this->runConditionVariable.notify_all();
std::scoped_lock<std::mutex> runLock(this->runMutex);
this->messageCount++;
this->runConditionVariable.notify_all();
}
else
{
static bool warned{false};
if (!warned)
{
ignwarn << "Queue size reached for connection" << std::endl;
}
}
}
else
{
Expand Down Expand Up @@ -813,22 +873,34 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
else if (frameParts[0] == "sub")
{
std::string topic = frameParts[1];

// check and update subscription count
if (!this->UpdateMsgTypeSubscriptionCount(topic, _socketId, true))
return;

// Store the relation of socketId to subscribed topic.
this->topicConnections[frameParts[1]].insert(_socketId);
this->topicTimestamps[frameParts[1]] =
this->topicConnections[topic].insert(_socketId);
this->topicTimestamps[topic] =
std::chrono::steady_clock::now() - this->publishPeriod;

igndbg << "Subscribe request to topic[" << frameParts[1] << "]\n";
this->node.SubscribeRaw(frameParts[1],
this->node.SubscribeRaw(topic,
std::bind(&WebsocketServer::OnWebsocketSubscribedMessage,
this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3));
}
else if (frameParts[0] == "image")
{
std::string topic = frameParts[1];

// check and update subscription count
if (!this->UpdateMsgTypeSubscriptionCount(topic, _socketId, true))
return;

// Store the relation of socketId to subscribed topic.
this->topicConnections[frameParts[1]].insert(_socketId);
this->topicTimestamps[frameParts[1]] =
this->topicConnections[topic].insert(_socketId);
this->topicTimestamps[topic] =
std::chrono::steady_clock::now() - this->publishPeriod;

std::vector<std::string> allTopics;
Expand All @@ -847,7 +919,7 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
}
}
std::string topic = frameParts[1];

if (!imageTopics.count(topic))
{
igndbg << "Could not find topic: " << topic << " to stream"
Expand All @@ -861,15 +933,25 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
else if (frameParts[0] == "unsub")
{
igndbg << "Unsubscribe request for topic[" << frameParts[1] << "]\n";
std::string topic = frameParts[1];

// check and update subscription count
this->UpdateMsgTypeSubscriptionCount(topic, _socketId, false);

igndbg << "Unsubscribe request for topic[" << topic << "]\n";
std::map<std::string, std::set<int>>::iterator topicConnectionIter =
this->topicConnections.find(frameParts[1]);
this->topicConnections.find(topic);

if (topicConnectionIter != this->topicConnections.end())
{
// Remove from the topic connections map
topicConnectionIter->second.erase(_socketId);

// remove from the connection's topic throttling maps
auto &con = this->connections[_socketId];
con->topicPublishPeriods.erase(topic);
con->topicTimestamps.erase(topic);

// Only unsubscribe from the Ignition Transport topic if there are no
// more websocket connections.
if (topicConnectionIter->second.empty())
Expand All @@ -882,7 +964,28 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
else
{
ignwarn << "The websocket server is not subscribed to topic["
<< frameParts[1] << "]. Unable to unsubscribe from the topic\n";
<< topic << "]. Unable to unsubscribe from the topic\n";
}
}
else if (frameParts[0] == "throttle")
{
std::string topic = frameParts[1];
igndbg << "Throttle request for topic[" << topic << "]\n";
if (!topic.empty())
{
try
{
int rate = std::stoi(frameParts[3]);
double period = 1.0 / static_cast<double>(rate);
this->connections[_socketId]->topicPublishPeriods[topic] =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(1.0 / rate));
}
catch (...)
{
ignwarn << "Unable to set topic rate for topic[" << topic
<< "]" << std::endl;
}
}
}
}
Expand Down Expand Up @@ -924,10 +1027,18 @@ void WebsocketServer::OnWebsocketSubscribedMessage(
// Send the message
for (const int &socketId : iter->second)
{
if (this->connections.find(socketId) != this->connections.end())
auto conIt = this->connections.find(socketId);
if (conIt != this->connections.end())
{
this->QueueMessage(this->connections[socketId].get(),
msg.c_str(), msg.length());
// do additional throttling based on client connection setting
auto lastPubTimeCon = conIt->second->topicTimestamps[_info.Topic()];
std::chrono::nanoseconds timeDeltaCon = systemTime - lastPubTimeCon;
if (timeDeltaCon > conIt->second->topicPublishPeriods[_info.Topic()])
{
conIt->second->topicTimestamps[_info.Topic()] = systemTime;
this->QueueMessage(conIt->second.get(),
msg.c_str(), msg.length());
}
}
}
}
Expand Down Expand Up @@ -1011,3 +1122,77 @@ void WebsocketServer::OnWebsocketSubscribedImageMessage(
}
}
}

//////////////////////////////////////////////////
bool WebsocketServer::UpdateMsgTypeSubscriptionCount(const std::string &_topic,
int _socketId, bool _subscribe)
{
// check if limit reached for the subscribed msg type
// if not, update subscription count
std::vector<transport::MessagePublisher> publishers;
this->node.TopicInfo(_topic, publishers);
if (!publishers.empty())
{
std::string msgType = publishers.begin()->MsgTypeName();
auto limitIt = this->msgTypeSubscriptionLimit.find(msgType);
if (limitIt != this->msgTypeSubscriptionLimit.end())
{
bool limitReached = false;
auto conIt = this->connections.find(_socketId);
if (conIt != this->connections.end())
{
auto &con = conIt->second;
auto &subCount = con->msgTypeSubscriptionCount;
auto countIt = subCount.find(msgType);

// if there is already a subscription on the topic for this connection
if (countIt != subCount.end())
{
// subscribe: increment count and check if reached limit
// unsubscribe: decrement count and make sure count is >= 0
if (_subscribe)
{
if (countIt->second + 1 <= limitIt->second)
{
countIt->second++;
}
else
{
limitReached = true;
}
}
else
{
countIt->second = std::max(0, countIt->second - 1);
}
}
// if topic not yet subscribed, set count to 1 on subscription
// ignore for unsubscribe option
else if (limitIt->second > 0)
{
if (_subscribe)
subCount[msgType] = 1;
}
// corner case when msg type subscription limit is set to 0
else if (_subscribe)
{
limitReached = true;
}
if (limitReached)
{
ignwarn << "Msg type subscription limit reached[" << msgType
<< ", " << limitIt->second << "] for connection[" << _socketId
<< "]" << std::endl;
return false;
}
}
else
{
ignwarn << "Unable to find connection[" << _socketId << "]"
<< " when setting subscription limit." << std::endl;
return false;
}
}
}
return true;
}
37 changes: 37 additions & 0 deletions plugins/websocket_server/WebsocketServer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ namespace ignition

public: void OnRequestMessage(int _socketId, const std::string &_msg);

/// \brief Check and update subscription count for a message type. If
/// a client has more subscriptions to a topic of a specified type than
/// the subscription limit, this will block subscription. On the other
/// hand, for an unsubscription operation, the count is decremented.
/// \param[in] _topic Topic to subscribe to or unsubscribe from
/// \param[in] _socketId Connection socket id
/// \param[in] _subscribe True for subscribe operation, false for
/// unsubscribe operation
/// \return True if the subscription count is incremented or decremented,
/// and false to indicate the subcription limit has reached.
public: bool UpdateMsgTypeSubscriptionCount(const std::string &_topic,
int _socketId, bool _subscribe);

private: ignition::transport::Node node;

private: bool run = true;
Expand All @@ -190,6 +203,21 @@ namespace ignition
public: std::mutex mutex;

public: bool authorized{false};

/// \brief A map of topic name to outbound publish rate
/// A value of 0 means unthrottled
public: std::map<std::string, std::chrono::nanoseconds>
topicPublishPeriods;

/// \brief A map of topic name to timestamp of last published message
/// for this connection
public: std::map<std::string,
std::chrono::time_point<std::chrono::steady_clock>> topicTimestamps;

/// \brief The number of subscriptions of a msg type this connection
/// has. The key is the msg type, e.g. ignition.msgs.Image, and the
/// value is the subscription count
public: std::map<std::string, int> msgTypeSubscriptionCount;
};

private: void QueueMessage(Connection *_connection,
Expand All @@ -209,6 +237,11 @@ namespace ignition
/// connections that have subscribed to the topic.
public: std::map<std::string, std::set<int>> topicConnections;

/// \brief The limit placed on the number of subscriptions per msg type
/// for each connection. The key is the msg type, e.g.
/// ignition.msgs.Image, and the value is the subscription limit
public: std::map<std::string, int> msgTypeSubscriptionLimit;

/// \brief Run loop mutex.
public: std::mutex runMutex;

Expand All @@ -230,6 +263,10 @@ namespace ignition
std::chrono::time_point<std::chrono::steady_clock>>
topicTimestamps;

/// \brief The message queue size per connection. A negative number
/// indicates no limit.
public: int queueSizePerConnection{-1};

/// \brief The set of valid operations. This enum must align with the
/// `operations` member variable.
private: enum Operation
Expand Down