Skip to content

Commit

Permalink
Extend websocket server to support message limits and throttling (#116)
Browse files Browse the repository at this point in the history
* implemented connection msg type limit, throttle, max queue size

Signed-off-by: Ian Chen <ichen@osrfoundation.org>

* update subcription count on sub / unsub

Signed-off-by: Ian Chen <ichen@osrfoundation.org>

* update doxy

Signed-off-by: Ian Chen <ichen@osrfoundation.org>

* Fix unsubscribe logic

Signed-off-by: Nate Koenig <nate@openrobotics.org>

Co-authored-by: Nate Koenig <nate@openrobotics.org>
  • Loading branch information
iche033 and Nate Koenig authored Jun 9, 2021
1 parent e48c613 commit c3dd804
Show file tree
Hide file tree
Showing 2 changed files with 242 additions and 20 deletions.
225 changes: 205 additions & 20 deletions plugins/websocket_server/WebsocketServer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ int rootCallback(struct lws *_wsi,
case LWS_CALLBACK_SERVER_WRITEABLE:
{
std::lock_guard<std::mutex> lock(self->connections[fd]->mutex);

if (!self->connections[fd]->buffer.empty())
{
int msgSize = self->connections[fd]->len.front();
Expand Down Expand Up @@ -360,7 +361,6 @@ bool WebsocketServer::Load(const tinyxml2::XMLElement *_elem)
}
this->publishPeriod = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(1.0 / hz));

// Get the authorization key, if present.
elem = _elem->FirstChildElement("authorization_key");
if (elem)
Expand Down Expand Up @@ -406,13 +406,62 @@ bool WebsocketServer::Load(const tinyxml2::XMLElement *_elem)
}
catch (...)
{
ignerr << "Failed to convert port[" << elem->GetText() << "] to integer."
<< std::endl;
ignerr << "Failed to convert max_connections[" << elem->GetText()
<< "] to integer." << std::endl;
}
igndbg << "Using maximum connection count of "
<< this->maxConnections << std::endl;
}

// Get the msg count per connection.
elem = _elem->FirstChildElement("queue_size_per_connection");
if (elem)
{
int size = -1;
auto result = elem->QueryIntText(&size);
if (result == tinyxml2::XML_SUCCESS && size >= 0)
{
this->queueSizePerConnection = size;
}
else
{
ignerr << "Failed to parse queue_size_per_connection["
<< elem->GetText() << "]." << std::endl;
}
igndbg << "Using connection msg queue size of "
<< this->queueSizePerConnection << std::endl;
}

// Get the msg type subscription limit
elem = _elem->FirstChildElement("subscription_limit_per_connection");
if (elem)
{
auto childElem = elem->FirstChildElement("subscription");
while (childElem)
{
auto msgTypeElem = childElem->FirstChildElement("msg_type");
auto limitElem = childElem->FirstChildElement("limit");
if (msgTypeElem && limitElem)
{
std::string msgType = msgTypeElem->GetText();
int limit = -1;
auto result = limitElem->QueryIntText(&limit);
if (result == tinyxml2::XML_SUCCESS && limit >= 0)
{
this->msgTypeSubscriptionLimit[msgType] = limit;
igndbg << "Setting msg type subscription limit[" << msgType
<< ", " << limit << "]" << std::endl;
}
else
{
ignerr << "Failed to parse subscription limit["
<< msgType << ", " << limitElem->GetText() << "]." << std::endl;
}
}
childElem = childElem->NextSiblingElement("subscription");
}
}

std::string sslCertFile = "";
std::string sslPrivateKeyFile = "";
elem = _elem->FirstChildElement("ssl");
Expand Down Expand Up @@ -527,12 +576,23 @@ void WebsocketServer::QueueMessage(Connection *_connection,
memcpy(buf.get() + LWS_PRE, _data, _size);

std::lock_guard<std::mutex> lock(_connection->mutex);
_connection->buffer.push_back(std::move(buf));
_connection->len.push_back(_size);
if (_connection->buffer.size() < this->queueSizePerConnection)
{
_connection->buffer.push_back(std::move(buf));
_connection->len.push_back(_size);

std::scoped_lock<std::mutex> runLock(this->runMutex);
this->messageCount++;
this->runConditionVariable.notify_all();
std::scoped_lock<std::mutex> runLock(this->runMutex);
this->messageCount++;
this->runConditionVariable.notify_all();
}
else
{
static bool warned{false};
if (!warned)
{
ignwarn << "Queue size reached for connection" << std::endl;
}
}
}
else
{
Expand Down Expand Up @@ -813,22 +873,34 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
else if (frameParts[0] == "sub")
{
std::string topic = frameParts[1];

// check and update subscription count
if (!this->UpdateMsgTypeSubscriptionCount(topic, _socketId, true))
return;

// Store the relation of socketId to subscribed topic.
this->topicConnections[frameParts[1]].insert(_socketId);
this->topicTimestamps[frameParts[1]] =
this->topicConnections[topic].insert(_socketId);
this->topicTimestamps[topic] =
std::chrono::steady_clock::now() - this->publishPeriod;

igndbg << "Subscribe request to topic[" << frameParts[1] << "]\n";
this->node.SubscribeRaw(frameParts[1],
this->node.SubscribeRaw(topic,
std::bind(&WebsocketServer::OnWebsocketSubscribedMessage,
this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3));
}
else if (frameParts[0] == "image")
{
std::string topic = frameParts[1];

// check and update subscription count
if (!this->UpdateMsgTypeSubscriptionCount(topic, _socketId, true))
return;

// Store the relation of socketId to subscribed topic.
this->topicConnections[frameParts[1]].insert(_socketId);
this->topicTimestamps[frameParts[1]] =
this->topicConnections[topic].insert(_socketId);
this->topicTimestamps[topic] =
std::chrono::steady_clock::now() - this->publishPeriod;

std::vector<std::string> allTopics;
Expand All @@ -847,7 +919,7 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
}
}
std::string topic = frameParts[1];

if (!imageTopics.count(topic))
{
igndbg << "Could not find topic: " << topic << " to stream"
Expand All @@ -861,15 +933,25 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
}
else if (frameParts[0] == "unsub")
{
igndbg << "Unsubscribe request for topic[" << frameParts[1] << "]\n";
std::string topic = frameParts[1];

igndbg << "Unsubscribe request for topic[" << topic << "]\n";
std::map<std::string, std::set<int>>::iterator topicConnectionIter =
this->topicConnections.find(frameParts[1]);
this->topicConnections.find(topic);

if (topicConnectionIter != this->topicConnections.end())
{
// Remove from the topic connections map
topicConnectionIter->second.erase(_socketId);

// remove from the connection's topic throttling maps
auto &con = this->connections[_socketId];
con->topicPublishPeriods.erase(topic);
con->topicTimestamps.erase(topic);

// check and update subscription count
this->UpdateMsgTypeSubscriptionCount(topic, _socketId, false);

// Only unsubscribe from the Ignition Transport topic if there are no
// more websocket connections.
if (topicConnectionIter->second.empty())
Expand All @@ -882,7 +964,28 @@ void WebsocketServer::OnMessage(int _socketId, const std::string &_msg)
else
{
ignwarn << "The websocket server is not subscribed to topic["
<< frameParts[1] << "]. Unable to unsubscribe from the topic\n";
<< topic << "]. Unable to unsubscribe from the topic\n";
}
}
else if (frameParts[0] == "throttle")
{
std::string topic = frameParts[1];
igndbg << "Throttle request for topic[" << topic << "]\n";
if (!topic.empty())
{
try
{
int rate = std::stoi(frameParts[3]);
double period = 1.0 / static_cast<double>(rate);
this->connections[_socketId]->topicPublishPeriods[topic] =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::duration<double>(1.0 / rate));
}
catch (...)
{
ignwarn << "Unable to set topic rate for topic[" << topic
<< "]" << std::endl;
}
}
}
}
Expand Down Expand Up @@ -924,10 +1027,18 @@ void WebsocketServer::OnWebsocketSubscribedMessage(
// Send the message
for (const int &socketId : iter->second)
{
if (this->connections.find(socketId) != this->connections.end())
auto conIt = this->connections.find(socketId);
if (conIt != this->connections.end())
{
this->QueueMessage(this->connections[socketId].get(),
msg.c_str(), msg.length());
// do additional throttling based on client connection setting
auto lastPubTimeCon = conIt->second->topicTimestamps[_info.Topic()];
std::chrono::nanoseconds timeDeltaCon = systemTime - lastPubTimeCon;
if (timeDeltaCon >= conIt->second->topicPublishPeriods[_info.Topic()])
{
conIt->second->topicTimestamps[_info.Topic()] = systemTime;
this->QueueMessage(conIt->second.get(),
msg.c_str(), msg.length());
}
}
}
}
Expand Down Expand Up @@ -1011,3 +1122,77 @@ void WebsocketServer::OnWebsocketSubscribedImageMessage(
}
}
}

//////////////////////////////////////////////////
bool WebsocketServer::UpdateMsgTypeSubscriptionCount(const std::string &_topic,
int _socketId, bool _subscribe)
{
// check if limit reached for the subscribed msg type
// if not, update subscription count
std::vector<transport::MessagePublisher> publishers;
this->node.TopicInfo(_topic, publishers);
if (!publishers.empty())
{
std::string msgType = publishers.begin()->MsgTypeName();
auto limitIt = this->msgTypeSubscriptionLimit.find(msgType);
if (limitIt != this->msgTypeSubscriptionLimit.end())
{
bool limitReached = false;
auto conIt = this->connections.find(_socketId);
if (conIt != this->connections.end())
{
auto &con = conIt->second;
auto &subCount = con->msgTypeSubscriptionCount;
auto countIt = subCount.find(msgType);

// if there is already a subscription on the topic for this connection
if (countIt != subCount.end())
{
// subscribe: increment count and check if reached limit
// unsubscribe: decrement count and make sure count is >= 0
if (_subscribe)
{
if (countIt->second + 1 <= limitIt->second)
{
countIt->second++;
}
else
{
limitReached = true;
}
}
else
{
countIt->second = std::max(0, countIt->second - 1);
}
}
// if topic not yet subscribed, set count to 1 on subscription
// ignore for unsubscribe option
else if (limitIt->second > 0)
{
if (_subscribe)
subCount[msgType] = 1;
}
// corner case when msg type subscription limit is set to 0
else if (_subscribe)
{
limitReached = true;
}
if (limitReached)
{
ignwarn << "Msg type subscription limit reached[" << msgType
<< ", " << limitIt->second << "] for connection[" << _socketId
<< "]" << std::endl;
return false;
}
}
else
{
ignwarn << "Unable to find connection[" << _socketId << "]"
<< " when setting subscription limit." << std::endl;
return false;
}
}
}
return true;
}
37 changes: 37 additions & 0 deletions plugins/websocket_server/WebsocketServer.hh
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,19 @@ namespace ignition

public: void OnRequestMessage(int _socketId, const std::string &_msg);

/// \brief Check and update subscription count for a message type. If
/// a client has more subscriptions to a topic of a specified type than
/// the subscription limit, this will block subscription. On the other
/// hand, for an unsubscription operation, the count is decremented.
/// \param[in] _topic Topic to subscribe to or unsubscribe from
/// \param[in] _socketId Connection socket id
/// \param[in] _subscribe True for subscribe operation, false for
/// unsubscribe operation
/// \return True if the subscription count is incremented or decremented,
/// and false to indicate the subcription limit has reached.
public: bool UpdateMsgTypeSubscriptionCount(const std::string &_topic,
int _socketId, bool _subscribe);

private: ignition::transport::Node node;

private: bool run = true;
Expand All @@ -190,6 +203,21 @@ namespace ignition
public: std::mutex mutex;

public: bool authorized{false};

/// \brief A map of topic name to outbound publish rate
/// A value of 0 means unthrottled
public: std::map<std::string, std::chrono::nanoseconds>
topicPublishPeriods;

/// \brief A map of topic name to timestamp of last published message
/// for this connection
public: std::map<std::string,
std::chrono::time_point<std::chrono::steady_clock>> topicTimestamps;

/// \brief The number of subscriptions of a msg type this connection
/// has. The key is the msg type, e.g. ignition.msgs.Image, and the
/// value is the subscription count
public: std::map<std::string, int> msgTypeSubscriptionCount;
};

private: void QueueMessage(Connection *_connection,
Expand All @@ -209,6 +237,11 @@ namespace ignition
/// connections that have subscribed to the topic.
public: std::map<std::string, std::set<int>> topicConnections;

/// \brief The limit placed on the number of subscriptions per msg type
/// for each connection. The key is the msg type, e.g.
/// ignition.msgs.Image, and the value is the subscription limit
public: std::map<std::string, int> msgTypeSubscriptionLimit;

/// \brief Run loop mutex.
public: std::mutex runMutex;

Expand All @@ -230,6 +263,10 @@ namespace ignition
std::chrono::time_point<std::chrono::steady_clock>>
topicTimestamps;

/// \brief The message queue size per connection. A negative number
/// indicates no limit.
public: int queueSizePerConnection{-1};

/// \brief The set of valid operations. This enum must align with the
/// `operations` member variable.
private: enum Operation
Expand Down

0 comments on commit c3dd804

Please sign in to comment.