From df100fbdae792648da56aef42eb35a1941bfe138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Javier=20Cervi=C3=B1o?= Date: Fri, 6 Apr 2018 13:10:34 +0200 Subject: [PATCH] Improve support to forward REMB packets to each stream (#1186) --- erizo/src/erizo/MediaStream.cpp | 5 + erizo/src/erizo/MediaStream.h | 3 +- erizo/src/erizo/WebRtcConnection.cpp | 53 +++++-- erizo/src/erizo/WebRtcConnection.h | 3 + erizo/src/erizo/rtp/RtpUtils.cpp | 20 +++ erizo/src/erizo/rtp/RtpUtils.h | 1 + erizo/src/test/WebRtcConnectionTest.cpp | 182 ++++++++++++++++++++++++ erizo/src/test/utils/Mocks.h | 34 ++++- 8 files changed, 288 insertions(+), 13 deletions(-) create mode 100644 erizo/src/test/WebRtcConnectionTest.cpp diff --git a/erizo/src/erizo/MediaStream.cpp b/erizo/src/erizo/MediaStream.cpp index 8fa0f9f738..00c37f46cd 100644 --- a/erizo/src/erizo/MediaStream.cpp +++ b/erizo/src/erizo/MediaStream.cpp @@ -86,6 +86,11 @@ MediaStream::~MediaStream() { ELOG_DEBUG("%s message: Destructor ended", toLog()); } +uint32_t MediaStream::getMaxVideoBW() { + uint32_t bitrate = rtcp_processor_ ? rtcp_processor_->getMaxVideoBW() : 0; + return bitrate; +} + void MediaStream::syncClose() { ELOG_DEBUG("%s message:Close called", toLog()); if (!sending_) { diff --git a/erizo/src/erizo/MediaStream.h b/erizo/src/erizo/MediaStream.h index 88875bbe5b..dc04512aeb 100644 --- a/erizo/src/erizo/MediaStream.h +++ b/erizo/src/erizo/MediaStream.h @@ -60,6 +60,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink, virtual ~MediaStream(); bool init(); void close() override; + virtual uint32_t getMaxVideoBW(); void syncClose(); bool setRemoteSdp(std::shared_ptr sdp); bool setLocalSdp(std::shared_ptr sdp); @@ -84,7 +85,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink, void getJSONStats(std::function callback); - void onTransportData(std::shared_ptr packet, Transport *transport); + virtual void onTransportData(std::shared_ptr packet, Transport *transport); void sendPacketAsync(std::shared_ptr packet); diff --git a/erizo/src/erizo/WebRtcConnection.cpp b/erizo/src/erizo/WebRtcConnection.cpp index 253eecdb9a..4cc5a10fae 100644 --- a/erizo/src/erizo/WebRtcConnection.cpp +++ b/erizo/src/erizo/WebRtcConnection.cpp @@ -71,6 +71,7 @@ void WebRtcConnection::syncClose() { return; } sending_ = false; + media_streams_.clear(); if (video_transport_.get()) { video_transport_->close(); } @@ -470,22 +471,49 @@ void WebRtcConnection::onCandidate(const CandidateInfo& cand, Transport *transpo } } +void WebRtcConnection::onREMBFromTransport(RtcpHeader *chead, Transport *transport) { + std::vector> streams; + + for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) { + uint32_t ssrc_feed = chead->getREMBFeedSSRC(index); + forEachMediaStream([ssrc_feed, &streams] (const std::shared_ptr &media_stream) { + if (media_stream->isSinkSSRC(ssrc_feed)) { + streams.push_back(media_stream); + } + }); + } + + std::sort(streams.begin(), streams.end(), + [](const std::shared_ptr &i, const std::shared_ptr &j) { + return i->getMaxVideoBW() < j->getMaxVideoBW(); + }); + + uint8_t remaining_streams = streams.size(); + uint32_t remaining_bitrate = chead->getREMBBitRate(); + std::for_each(streams.begin(), streams.end(), + [&remaining_bitrate, &remaining_streams, transport, chead](const std::shared_ptr &stream) { + uint32_t max_bitrate = stream->getMaxVideoBW(); + uint32_t remaining_avg_bitrate = remaining_bitrate / remaining_streams; + uint32_t bitrate = std::min(max_bitrate, remaining_avg_bitrate); + auto generated_remb = RtpUtils::createREMB(chead->getSSRC(), {stream->getVideoSinkSSRC()}, bitrate); + stream->onTransportData(generated_remb, transport); + remaining_bitrate -= bitrate; + remaining_streams--; + }); +} + void WebRtcConnection::onRtcpFromTransport(std::shared_ptr packet, Transport *transport) { RtpUtils::forEachRtcpBlock(packet, [this, packet, transport](RtcpHeader *chead) { uint32_t ssrc = chead->isFeedback() ? chead->getSourceSSRC() : chead->getSSRC(); + if (chead->isREMB()) { + onREMBFromTransport(chead, transport); + return; + } std::shared_ptr rtcp = std::make_shared(*packet); rtcp->length = (ntohs(chead->length) + 1) * 4; std::memcpy(rtcp->data, chead, rtcp->length); - forEachMediaStream([rtcp, transport, ssrc, chead] (const std::shared_ptr &media_stream) { - if (chead->isREMB()) { - for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) { - uint32_t ssrc_feed = chead->getREMBFeedSSRC(index); - if (media_stream->isSourceSSRC(ssrc_feed) || media_stream->isSinkSSRC(ssrc_feed)) { - // TODO(javier): Calculate the portion of bitrate that corresponds to this stream. - media_stream->onTransportData(rtcp, transport); - } - } - } else if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) { + forEachMediaStream([rtcp, transport, ssrc] (const std::shared_ptr &media_stream) { + if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) { media_stream->onTransportData(rtcp, transport); } }); @@ -681,4 +709,9 @@ void WebRtcConnection::syncWrite(std::shared_ptr packet) { transport->write(packet->data, packet->length); } +void WebRtcConnection::setTransport(std::shared_ptr transport) { // Only for Testing purposes + video_transport_ = transport; + bundle_ = true; +} + } // namespace erizo diff --git a/erizo/src/erizo/WebRtcConnection.h b/erizo/src/erizo/WebRtcConnection.h index 686a2a236e..ec84465bc8 100644 --- a/erizo/src/erizo/WebRtcConnection.h +++ b/erizo/src/erizo/WebRtcConnection.h @@ -142,6 +142,8 @@ class WebRtcConnection: public TransportListener, public LogContext, void forEachMediaStream(std::function&)> func); void forEachMediaStreamAsync(std::function&)> func); + void setTransport(std::shared_ptr transport); // Only for Testing purposes + std::shared_ptr getStatsService() { return stats_; } RtpExtensionProcessor& getRtpExtensionProcessor() { return extension_processor_; } @@ -159,6 +161,7 @@ class WebRtcConnection: public TransportListener, public LogContext, std::string getJSONCandidate(const std::string& mid, const std::string& sdp); void trackTransportInfo(); void onRtcpFromTransport(std::shared_ptr packet, Transport *transport); + void onREMBFromTransport(RtcpHeader *chead, Transport *transport); private: std::string connection_id_; diff --git a/erizo/src/erizo/rtp/RtpUtils.cpp b/erizo/src/erizo/rtp/RtpUtils.cpp index 82f4230e09..62c263cc39 100644 --- a/erizo/src/erizo/rtp/RtpUtils.cpp +++ b/erizo/src/erizo/rtp/RtpUtils.cpp @@ -86,6 +86,26 @@ std::shared_ptr RtpUtils::createFIR(uint32_t source_ssrc, uint32_t s return std::make_shared(0, buf, len, VIDEO_PACKET); } +std::shared_ptr RtpUtils::createREMB(uint32_t ssrc, std::vector ssrc_list, uint32_t bitrate) { + erizo::RtcpHeader remb; + remb.setPacketType(RTCP_PS_Feedback_PT); + remb.setBlockCount(RTCP_AFB); + memcpy(&remb.report.rembPacket.uniqueid, "REMB", 4); + + remb.setSSRC(ssrc); + remb.setSourceSSRC(0); + remb.setLength(4 + ssrc_list.size()); + remb.setREMBBitRate(bitrate); + remb.setREMBNumSSRC(ssrc_list.size()); + uint8_t index = 0; + for (uint32_t feed_ssrc : ssrc_list) { + remb.setREMBFeedSSRC(index++, feed_ssrc); + } + int len = (remb.getLength() + 1) * 4; + char *buf = reinterpret_cast(&remb); + return std::make_shared(0, buf, len, erizo::OTHER_PACKET); +} + int RtpUtils::getPaddingLength(std::shared_ptr packet) { RtpHeader *rtp_header = reinterpret_cast(packet->data); diff --git a/erizo/src/erizo/rtp/RtpUtils.h b/erizo/src/erizo/rtp/RtpUtils.h index ff47a2d50c..902277653d 100644 --- a/erizo/src/erizo/rtp/RtpUtils.h +++ b/erizo/src/erizo/rtp/RtpUtils.h @@ -28,6 +28,7 @@ class RtpUtils { static std::shared_ptr createPLI(uint32_t source_ssrc, uint32_t sink_ssrc); static std::shared_ptr createFIR(uint32_t source_ssrc, uint32_t sink_ssrc, uint8_t seq_number); + static std::shared_ptr createREMB(uint32_t ssrc, std::vector ssrc_list, uint32_t bitrate); static int getPaddingLength(std::shared_ptr packet); diff --git a/erizo/src/test/WebRtcConnectionTest.cpp b/erizo/src/test/WebRtcConnectionTest.cpp new file mode 100644 index 0000000000..073b3061b9 --- /dev/null +++ b/erizo/src/test/WebRtcConnectionTest.cpp @@ -0,0 +1,182 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "utils/Mocks.h" +#include "utils/Matchers.h" + +using testing::_; +using testing::Return; +using testing::Eq; +using testing::Args; +using testing::AtLeast; +using erizo::DataPacket; +using erizo::ExtMap; +using erizo::IceConfig; +using erizo::RtpMap; +using erizo::RtpUtils; +using erizo::WebRtcConnection; + +typedef std::vector MaxList; +typedef std::vector EnabledList; +typedef std::vector ExpectedList; + +class WebRtcConnectionTest : + public ::testing::TestWithParam> { + protected: + virtual void SetUp() { + index = 0; + simulated_clock = std::make_shared(); + simulated_worker = std::make_shared(simulated_clock); + simulated_worker->start(); + io_worker = std::make_shared(); + io_worker->start(); + connection = std::make_shared(simulated_worker, io_worker, + "test_connection", ice_config, rtp_maps, ext_maps, nullptr); + transport = std::make_shared("test_connection", true, ice_config, + simulated_worker, io_worker); + connection->setTransport(transport); + connection->updateState(TRANSPORT_READY, transport.get()); + max_video_bw_list = std::tr1::get<0>(GetParam()); + bitrate_value = std::tr1::get<1>(GetParam()); + add_to_remb_list = std::tr1::get<2>(GetParam()); + expected_bitrates = std::tr1::get<3>(GetParam()); + + setUpStreams(); + } + + void setUpStreams() { + for (uint32_t max_video_bw : max_video_bw_list) { + streams.push_back(addMediaStream(false, max_video_bw)); + } + } + + std::shared_ptr addMediaStream(bool is_publisher, uint32_t max_video_bw) { + std::string id = std::to_string(index); + std::string label = std::to_string(index); + uint32_t video_sink_ssrc = getSsrcFromIndex(index); + uint32_t audio_sink_ssrc = getSsrcFromIndex(index) + 1; + uint32_t video_source_ssrc = getSsrcFromIndex(index) + 2; + uint32_t audio_source_ssrc = getSsrcFromIndex(index) + 3; + auto media_stream = std::make_shared(simulated_worker, connection, id, label, + rtp_maps, is_publisher); + media_stream->setVideoSinkSSRC(video_sink_ssrc); + media_stream->setAudioSinkSSRC(audio_sink_ssrc); + media_stream->setVideoSourceSSRC(video_source_ssrc); + media_stream->setAudioSourceSSRC(audio_source_ssrc); + connection->addMediaStream(media_stream); + simulated_worker->executeTasks(); + EXPECT_CALL(*media_stream, getMaxVideoBW()).Times(AtLeast(0)).WillRepeatedly(Return(max_video_bw)); + index++; + return media_stream; + } + + void onRembReceived(uint32_t bitrate, std::vector ids) { + std::transform(ids.begin(), ids.end(), ids.begin(), [](uint32_t id) { + return id * 1000; + }); + auto remb = RtpUtils::createREMB(ids[0], ids, bitrate); + connection->onTransportData(remb, transport.get()); + } + + void onRembReceived() { + uint32_t index = 0; + std::vector ids; + for (bool enabled : add_to_remb_list) { + if (enabled) { + ids.push_back(index); + } + index++; + } + onRembReceived(bitrate_value, ids); + } + + uint32_t getIndexFromSsrc(uint32_t ssrc) { + return ssrc / 1000; + } + + uint32_t getSsrcFromIndex(uint32_t index) { + return index * 1000; + } + + virtual void TearDown() { + connection->close(); + simulated_worker->executeTasks(); + streams.clear(); + } + + std::vector> streams; + MaxList max_video_bw_list; + uint32_t bitrate_value; + EnabledList add_to_remb_list; + ExpectedList expected_bitrates; + IceConfig ice_config; + std::vector rtp_maps; + std::vector ext_maps; + uint32_t index; + std::shared_ptr transport; + std::shared_ptr connection; + std::shared_ptr processor; + std::shared_ptr simulated_clock; + std::shared_ptr simulated_worker; + std::shared_ptr io_worker; + std::queue> packet_queue; +}; + +TEST_P(WebRtcConnectionTest, forwardRembToStreams_When_StreamTheyExist) { + uint32_t index = 0; + for (int32_t expected_bitrate : expected_bitrates) { + if (expected_bitrate > 0) { + EXPECT_CALL(*(streams[index]), onTransportData(_, _)) + .With(Args<0>(erizo::RembHasBitrateValue(static_cast(expected_bitrate)))).Times(1); + } else { + EXPECT_CALL(*streams[index], onTransportData(_, _)).Times(0); + } + index++; + } + + onRembReceived(); +} + +INSTANTIATE_TEST_CASE_P( + REMB_values, WebRtcConnectionTest, testing::Values( + std::make_tuple(MaxList{300}, 100, EnabledList{1}, ExpectedList{100}), + std::make_tuple(MaxList{300}, 600, EnabledList{1}, ExpectedList{300}), + + std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 0}, ExpectedList{300, -1}), + std::make_tuple(MaxList{300, 300}, 300, EnabledList{0, 1}, ExpectedList{-1, 300}), + std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 1}, ExpectedList{150, 150}), + std::make_tuple(MaxList{100, 300}, 300, EnabledList{1, 1}, ExpectedList{100, 200}), + std::make_tuple(MaxList{300, 100}, 300, EnabledList{1, 1}, ExpectedList{200, 100}), + std::make_tuple(MaxList{100, 100}, 300, EnabledList{1, 1}, ExpectedList{100, 100}), + + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 0, 0}, ExpectedList{300, -1, -1}), + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{ -1, 300, -1}), + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{150, 150, -1}), + std::make_tuple(MaxList{100, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 200, -1}), + std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{200, 100, -1}), + std::make_tuple(MaxList{100, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 100, -1}), + + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{-1, 300, -1}), + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 0, 1}, ExpectedList{-1, -1, 300}), + std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 150, 150}), + std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 200}), + std::make_tuple(MaxList{300, 300, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 200, 100}), + std::make_tuple(MaxList{300, 100, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 100}), + + std::make_tuple(MaxList{100, 100, 100}, 300, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}), + std::make_tuple(MaxList{100, 100, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}), + std::make_tuple(MaxList{300, 300, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{200, 200, 200}), + std::make_tuple(MaxList{100, 200, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 200, 300}), + std::make_tuple(MaxList{300, 200, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{300, 200, 100}), + std::make_tuple(MaxList{100, 500, 500}, 800, EnabledList{1, 1, 1}, ExpectedList{100, 350, 350}))); diff --git a/erizo/src/test/utils/Mocks.h b/erizo/src/test/utils/Mocks.h index 1c456559cc..3a1761777c 100644 --- a/erizo/src/test/utils/Mocks.h +++ b/erizo/src/test/utils/Mocks.h @@ -60,6 +60,33 @@ class MockMediaSink : public MediaSink { } }; +class MockTransport: public Transport { + public: + MockTransport(std::string connection_id, bool bundle, const IceConfig &ice_config, + std::shared_ptr worker, std::shared_ptr io_worker) : + Transport(VIDEO_TYPE, "video", connection_id, bundle, true, + std::shared_ptr(nullptr), ice_config, + worker, io_worker) {} + + virtual ~MockTransport() { + } + + void updateIceState(IceState state, IceConnection *conn) override { + } + void onIceData(packetPtr packet) override { + } + void onCandidate(const CandidateInfo &candidate, IceConnection *conn) override { + } + void write(char* data, int len) override { + } + void processLocalSdp(SdpInfo *localSdp_) override { + } + void start() override { + } + void close() override { + } +}; + class MockWebRtcConnection: public WebRtcConnection { public: MockWebRtcConnection(std::shared_ptr worker, std::shared_ptr io_worker, const IceConfig &ice_config, @@ -74,11 +101,14 @@ class MockMediaStream: public MediaStream { public: MockMediaStream(std::shared_ptr worker, std::shared_ptr connection, const std::string& media_stream_id, const std::string& media_stream_label, - std::vector rtp_mappings) : - MediaStream(worker, connection, media_stream_id, media_stream_label, true) { + std::vector rtp_mappings, bool is_publisher = true) : + MediaStream(worker, connection, media_stream_id, media_stream_label, is_publisher) { local_sdp_ = std::make_shared(rtp_mappings); remote_sdp_ = std::make_shared(rtp_mappings); } + + MOCK_METHOD0(getMaxVideoBW, uint32_t()); + MOCK_METHOD2(onTransportData, void(std::shared_ptr, Transport*)); }; class Reader : public InboundHandler {