Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support to forward REMB packets to each stream #1186

Merged
merged 2 commits into from
Apr 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions erizo/src/erizo/MediaStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ MediaStream::~MediaStream() {
ELOG_DEBUG("%s message: Destructor ended", toLog());
}

uint32_t MediaStream::getMaxVideoBW() {
uint32_t bitrate = rtcp_processor_ ? rtcp_processor_->getMaxVideoBW() : 0;
return bitrate;
}

void MediaStream::syncClose() {
ELOG_DEBUG("%s message:Close called", toLog());
if (!sending_) {
Expand Down
3 changes: 2 additions & 1 deletion erizo/src/erizo/MediaStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink,
virtual ~MediaStream();
bool init();
void close() override;
virtual uint32_t getMaxVideoBW();
void syncClose();
bool setRemoteSdp(std::shared_ptr<SdpInfo> sdp);
bool setLocalSdp(std::shared_ptr<SdpInfo> sdp);
Expand All @@ -84,7 +85,7 @@ class MediaStream: public MediaSink, public MediaSource, public FeedbackSink,

void getJSONStats(std::function<void(std::string)> callback);

void onTransportData(std::shared_ptr<DataPacket> packet, Transport *transport);
virtual void onTransportData(std::shared_ptr<DataPacket> packet, Transport *transport);

void sendPacketAsync(std::shared_ptr<DataPacket> packet);

Expand Down
53 changes: 43 additions & 10 deletions erizo/src/erizo/WebRtcConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ void WebRtcConnection::syncClose() {
return;
}
sending_ = false;
media_streams_.clear();
if (video_transport_.get()) {
video_transport_->close();
}
Expand Down Expand Up @@ -470,22 +471,49 @@ void WebRtcConnection::onCandidate(const CandidateInfo& cand, Transport *transpo
}
}

void WebRtcConnection::onREMBFromTransport(RtcpHeader *chead, Transport *transport) {
std::vector<std::shared_ptr<MediaStream>> streams;

for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) {
uint32_t ssrc_feed = chead->getREMBFeedSSRC(index);
forEachMediaStream([ssrc_feed, &streams] (const std::shared_ptr<MediaStream> &media_stream) {
if (media_stream->isSinkSSRC(ssrc_feed)) {
streams.push_back(media_stream);
}
});
}

std::sort(streams.begin(), streams.end(),
[](const std::shared_ptr<MediaStream> &i, const std::shared_ptr<MediaStream> &j) {
return i->getMaxVideoBW() < j->getMaxVideoBW();
});

uint8_t remaining_streams = streams.size();
uint32_t remaining_bitrate = chead->getREMBBitRate();
std::for_each(streams.begin(), streams.end(),
[&remaining_bitrate, &remaining_streams, transport, chead](const std::shared_ptr<MediaStream> &stream) {
uint32_t max_bitrate = stream->getMaxVideoBW();
uint32_t remaining_avg_bitrate = remaining_bitrate / remaining_streams;
uint32_t bitrate = std::min(max_bitrate, remaining_avg_bitrate);
auto generated_remb = RtpUtils::createREMB(chead->getSSRC(), {stream->getVideoSinkSSRC()}, bitrate);
stream->onTransportData(generated_remb, transport);
remaining_bitrate -= bitrate;
remaining_streams--;
});
}

void WebRtcConnection::onRtcpFromTransport(std::shared_ptr<DataPacket> packet, Transport *transport) {
RtpUtils::forEachRtcpBlock(packet, [this, packet, transport](RtcpHeader *chead) {
uint32_t ssrc = chead->isFeedback() ? chead->getSourceSSRC() : chead->getSSRC();
if (chead->isREMB()) {
onREMBFromTransport(chead, transport);
return;
}
std::shared_ptr<DataPacket> rtcp = std::make_shared<DataPacket>(*packet);
rtcp->length = (ntohs(chead->length) + 1) * 4;
std::memcpy(rtcp->data, chead, rtcp->length);
forEachMediaStream([rtcp, transport, ssrc, chead] (const std::shared_ptr<MediaStream> &media_stream) {
if (chead->isREMB()) {
for (uint8_t index = 0; index < chead->getREMBNumSSRC(); index++) {
uint32_t ssrc_feed = chead->getREMBFeedSSRC(index);
if (media_stream->isSourceSSRC(ssrc_feed) || media_stream->isSinkSSRC(ssrc_feed)) {
// TODO(javier): Calculate the portion of bitrate that corresponds to this stream.
media_stream->onTransportData(rtcp, transport);
}
}
} else if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) {
forEachMediaStream([rtcp, transport, ssrc] (const std::shared_ptr<MediaStream> &media_stream) {
if (media_stream->isSourceSSRC(ssrc) || media_stream->isSinkSSRC(ssrc)) {
media_stream->onTransportData(rtcp, transport);
}
});
Expand Down Expand Up @@ -681,4 +709,9 @@ void WebRtcConnection::syncWrite(std::shared_ptr<DataPacket> packet) {
transport->write(packet->data, packet->length);
}

void WebRtcConnection::setTransport(std::shared_ptr<Transport> transport) { // Only for Testing purposes
video_transport_ = transport;
bundle_ = true;
}

} // namespace erizo
3 changes: 3 additions & 0 deletions erizo/src/erizo/WebRtcConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ class WebRtcConnection: public TransportListener, public LogContext,
void forEachMediaStream(std::function<void(const std::shared_ptr<MediaStream>&)> func);
void forEachMediaStreamAsync(std::function<void(const std::shared_ptr<MediaStream>&)> func);

void setTransport(std::shared_ptr<Transport> transport); // Only for Testing purposes

std::shared_ptr<Stats> getStatsService() { return stats_; }

RtpExtensionProcessor& getRtpExtensionProcessor() { return extension_processor_; }
Expand All @@ -159,6 +161,7 @@ class WebRtcConnection: public TransportListener, public LogContext,
std::string getJSONCandidate(const std::string& mid, const std::string& sdp);
void trackTransportInfo();
void onRtcpFromTransport(std::shared_ptr<DataPacket> packet, Transport *transport);
void onREMBFromTransport(RtcpHeader *chead, Transport *transport);

private:
std::string connection_id_;
Expand Down
20 changes: 20 additions & 0 deletions erizo/src/erizo/rtp/RtpUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ std::shared_ptr<DataPacket> RtpUtils::createFIR(uint32_t source_ssrc, uint32_t s
return std::make_shared<DataPacket>(0, buf, len, VIDEO_PACKET);
}

std::shared_ptr<DataPacket> RtpUtils::createREMB(uint32_t ssrc, std::vector<uint32_t> ssrc_list, uint32_t bitrate) {
erizo::RtcpHeader remb;
remb.setPacketType(RTCP_PS_Feedback_PT);
remb.setBlockCount(RTCP_AFB);
memcpy(&remb.report.rembPacket.uniqueid, "REMB", 4);

remb.setSSRC(ssrc);
remb.setSourceSSRC(0);
remb.setLength(4 + ssrc_list.size());
remb.setREMBBitRate(bitrate);
remb.setREMBNumSSRC(ssrc_list.size());
uint8_t index = 0;
for (uint32_t feed_ssrc : ssrc_list) {
remb.setREMBFeedSSRC(index++, feed_ssrc);
}
int len = (remb.getLength() + 1) * 4;
char *buf = reinterpret_cast<char*>(&remb);
return std::make_shared<erizo::DataPacket>(0, buf, len, erizo::OTHER_PACKET);
}


int RtpUtils::getPaddingLength(std::shared_ptr<DataPacket> packet) {
RtpHeader *rtp_header = reinterpret_cast<RtpHeader*>(packet->data);
Expand Down
1 change: 1 addition & 0 deletions erizo/src/erizo/rtp/RtpUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class RtpUtils {
static std::shared_ptr<DataPacket> createPLI(uint32_t source_ssrc, uint32_t sink_ssrc);

static std::shared_ptr<DataPacket> createFIR(uint32_t source_ssrc, uint32_t sink_ssrc, uint8_t seq_number);
static std::shared_ptr<DataPacket> createREMB(uint32_t ssrc, std::vector<uint32_t> ssrc_list, uint32_t bitrate);

static int getPaddingLength(std::shared_ptr<DataPacket> packet);

Expand Down
182 changes: 182 additions & 0 deletions erizo/src/test/WebRtcConnectionTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <rtp/RtpHeaders.h>
#include <rtp/RtpUtils.h>
#include <MediaDefinitions.h>
#include <WebRtcConnection.h>

#include <string>
#include <tuple>

#include "utils/Mocks.h"
#include "utils/Matchers.h"

using testing::_;
using testing::Return;
using testing::Eq;
using testing::Args;
using testing::AtLeast;
using erizo::DataPacket;
using erizo::ExtMap;
using erizo::IceConfig;
using erizo::RtpMap;
using erizo::RtpUtils;
using erizo::WebRtcConnection;

typedef std::vector<uint32_t> MaxList;
typedef std::vector<bool> EnabledList;
typedef std::vector<int32_t> ExpectedList;

class WebRtcConnectionTest :
public ::testing::TestWithParam<std::tr1::tuple<MaxList,
uint32_t,
EnabledList,
ExpectedList>> {
protected:
virtual void SetUp() {
index = 0;
simulated_clock = std::make_shared<erizo::SimulatedClock>();
simulated_worker = std::make_shared<erizo::SimulatedWorker>(simulated_clock);
simulated_worker->start();
io_worker = std::make_shared<erizo::IOWorker>();
io_worker->start();
connection = std::make_shared<WebRtcConnection>(simulated_worker, io_worker,
"test_connection", ice_config, rtp_maps, ext_maps, nullptr);
transport = std::make_shared<erizo::MockTransport>("test_connection", true, ice_config,
simulated_worker, io_worker);
connection->setTransport(transport);
connection->updateState(TRANSPORT_READY, transport.get());
max_video_bw_list = std::tr1::get<0>(GetParam());
bitrate_value = std::tr1::get<1>(GetParam());
add_to_remb_list = std::tr1::get<2>(GetParam());
expected_bitrates = std::tr1::get<3>(GetParam());

setUpStreams();
}

void setUpStreams() {
for (uint32_t max_video_bw : max_video_bw_list) {
streams.push_back(addMediaStream(false, max_video_bw));
}
}

std::shared_ptr<erizo::MockMediaStream> addMediaStream(bool is_publisher, uint32_t max_video_bw) {
std::string id = std::to_string(index);
std::string label = std::to_string(index);
uint32_t video_sink_ssrc = getSsrcFromIndex(index);
uint32_t audio_sink_ssrc = getSsrcFromIndex(index) + 1;
uint32_t video_source_ssrc = getSsrcFromIndex(index) + 2;
uint32_t audio_source_ssrc = getSsrcFromIndex(index) + 3;
auto media_stream = std::make_shared<erizo::MockMediaStream>(simulated_worker, connection, id, label,
rtp_maps, is_publisher);
media_stream->setVideoSinkSSRC(video_sink_ssrc);
media_stream->setAudioSinkSSRC(audio_sink_ssrc);
media_stream->setVideoSourceSSRC(video_source_ssrc);
media_stream->setAudioSourceSSRC(audio_source_ssrc);
connection->addMediaStream(media_stream);
simulated_worker->executeTasks();
EXPECT_CALL(*media_stream, getMaxVideoBW()).Times(AtLeast(0)).WillRepeatedly(Return(max_video_bw));
index++;
return media_stream;
}

void onRembReceived(uint32_t bitrate, std::vector<uint32_t> ids) {
std::transform(ids.begin(), ids.end(), ids.begin(), [](uint32_t id) {
return id * 1000;
});
auto remb = RtpUtils::createREMB(ids[0], ids, bitrate);
connection->onTransportData(remb, transport.get());
}

void onRembReceived() {
uint32_t index = 0;
std::vector<uint32_t> ids;
for (bool enabled : add_to_remb_list) {
if (enabled) {
ids.push_back(index);
}
index++;
}
onRembReceived(bitrate_value, ids);
}

uint32_t getIndexFromSsrc(uint32_t ssrc) {
return ssrc / 1000;
}

uint32_t getSsrcFromIndex(uint32_t index) {
return index * 1000;
}

virtual void TearDown() {
connection->close();
simulated_worker->executeTasks();
streams.clear();
}

std::vector<std::shared_ptr<erizo::MockMediaStream>> streams;
MaxList max_video_bw_list;
uint32_t bitrate_value;
EnabledList add_to_remb_list;
ExpectedList expected_bitrates;
IceConfig ice_config;
std::vector<RtpMap> rtp_maps;
std::vector<ExtMap> ext_maps;
uint32_t index;
std::shared_ptr<erizo::MockTransport> transport;
std::shared_ptr<WebRtcConnection> connection;
std::shared_ptr<erizo::MockRtcpProcessor> processor;
std::shared_ptr<erizo::SimulatedClock> simulated_clock;
std::shared_ptr<erizo::SimulatedWorker> simulated_worker;
std::shared_ptr<erizo::IOWorker> io_worker;
std::queue<std::shared_ptr<DataPacket>> packet_queue;
};

TEST_P(WebRtcConnectionTest, forwardRembToStreams_When_StreamTheyExist) {
uint32_t index = 0;
for (int32_t expected_bitrate : expected_bitrates) {
if (expected_bitrate > 0) {
EXPECT_CALL(*(streams[index]), onTransportData(_, _))
.With(Args<0>(erizo::RembHasBitrateValue(static_cast<uint32_t>(expected_bitrate)))).Times(1);
} else {
EXPECT_CALL(*streams[index], onTransportData(_, _)).Times(0);
}
index++;
}

onRembReceived();
}

INSTANTIATE_TEST_CASE_P(
REMB_values, WebRtcConnectionTest, testing::Values(
std::make_tuple(MaxList{300}, 100, EnabledList{1}, ExpectedList{100}),
std::make_tuple(MaxList{300}, 600, EnabledList{1}, ExpectedList{300}),

std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 0}, ExpectedList{300, -1}),
std::make_tuple(MaxList{300, 300}, 300, EnabledList{0, 1}, ExpectedList{-1, 300}),
std::make_tuple(MaxList{300, 300}, 300, EnabledList{1, 1}, ExpectedList{150, 150}),
std::make_tuple(MaxList{100, 300}, 300, EnabledList{1, 1}, ExpectedList{100, 200}),
std::make_tuple(MaxList{300, 100}, 300, EnabledList{1, 1}, ExpectedList{200, 100}),
std::make_tuple(MaxList{100, 100}, 300, EnabledList{1, 1}, ExpectedList{100, 100}),

std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 0, 0}, ExpectedList{300, -1, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{ -1, 300, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{150, 150, -1}),
std::make_tuple(MaxList{100, 300, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 200, -1}),
std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{200, 100, -1}),
std::make_tuple(MaxList{100, 100, 300}, 300, EnabledList{1, 1, 0}, ExpectedList{100, 100, -1}),

std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 0}, ExpectedList{-1, 300, -1}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 0, 1}, ExpectedList{-1, -1, 300}),
std::make_tuple(MaxList{300, 300, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 150, 150}),
std::make_tuple(MaxList{300, 100, 300}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 200}),
std::make_tuple(MaxList{300, 300, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 200, 100}),
std::make_tuple(MaxList{300, 100, 100}, 300, EnabledList{0, 1, 1}, ExpectedList{-1, 100, 100}),

std::make_tuple(MaxList{100, 100, 100}, 300, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}),
std::make_tuple(MaxList{100, 100, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 100, 100}),
std::make_tuple(MaxList{300, 300, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{200, 200, 200}),
std::make_tuple(MaxList{100, 200, 300}, 600, EnabledList{1, 1, 1}, ExpectedList{100, 200, 300}),
std::make_tuple(MaxList{300, 200, 100}, 600, EnabledList{1, 1, 1}, ExpectedList{300, 200, 100}),
std::make_tuple(MaxList{100, 500, 500}, 800, EnabledList{1, 1, 1}, ExpectedList{100, 350, 350})));
34 changes: 32 additions & 2 deletions erizo/src/test/utils/Mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,33 @@ class MockMediaSink : public MediaSink {
}
};

class MockTransport: public Transport {
public:
MockTransport(std::string connection_id, bool bundle, const IceConfig &ice_config,
std::shared_ptr<Worker> worker, std::shared_ptr<IOWorker> io_worker) :
Transport(VIDEO_TYPE, "video", connection_id, bundle, true,
std::shared_ptr<erizo::TransportListener>(nullptr), ice_config,
worker, io_worker) {}

virtual ~MockTransport() {
}

void updateIceState(IceState state, IceConnection *conn) override {
}
void onIceData(packetPtr packet) override {
}
void onCandidate(const CandidateInfo &candidate, IceConnection *conn) override {
}
void write(char* data, int len) override {
}
void processLocalSdp(SdpInfo *localSdp_) override {
}
void start() override {
}
void close() override {
}
};

class MockWebRtcConnection: public WebRtcConnection {
public:
MockWebRtcConnection(std::shared_ptr<Worker> worker, std::shared_ptr<IOWorker> io_worker, const IceConfig &ice_config,
Expand All @@ -74,11 +101,14 @@ class MockMediaStream: public MediaStream {
public:
MockMediaStream(std::shared_ptr<Worker> worker, std::shared_ptr<WebRtcConnection> connection,
const std::string& media_stream_id, const std::string& media_stream_label,
std::vector<RtpMap> rtp_mappings) :
MediaStream(worker, connection, media_stream_id, media_stream_label, true) {
std::vector<RtpMap> rtp_mappings, bool is_publisher = true) :
MediaStream(worker, connection, media_stream_id, media_stream_label, is_publisher) {
local_sdp_ = std::make_shared<SdpInfo>(rtp_mappings);
remote_sdp_ = std::make_shared<SdpInfo>(rtp_mappings);
}

MOCK_METHOD0(getMaxVideoBW, uint32_t());
MOCK_METHOD2(onTransportData, void(std::shared_ptr<DataPacket>, Transport*));
};

class Reader : public InboundHandler {
Expand Down