Skip to content

Commit

Permalink
[ntuple] Add support for std::map fields
Browse files Browse the repository at this point in the history
This change in theory also adds support for other associative
collections, provided they implement a collection proxy. Because the use
case for this is not yet obvious, this is still disabled.
  • Loading branch information
enirolf committed Oct 20, 2023
1 parent 5a039c2 commit 972a45e
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 4 deletions.
76 changes: 72 additions & 4 deletions tree/ntuple/v7/inc/ROOT/RField.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <new>
#include <set>
Expand Down Expand Up @@ -859,8 +860,8 @@ protected:
void GenerateValue(void *where) const override;
void DestroyValue(void *objPtr, bool dtorOnly = false) const override;

std::size_t AppendImpl(const void *from) final;
void ReadGlobalImpl(NTupleSize_t globalIndex, void *to) final;
std::size_t AppendImpl(const void *from) override;
void ReadGlobalImpl(NTupleSize_t globalIndex, void *to) override;

void CommitClusterImpl() final { fNWritten = 0; }

Expand All @@ -871,10 +872,10 @@ public:
~RProxiedCollectionField() override = default;

using Detail::RFieldBase::GenerateValue;
std::vector<RValue> SplitValue(const RValue &value) const final;
std::vector<RValue> SplitValue(const RValue &value) const override;
size_t GetValueSize() const override { return fProxy->Sizeof(); }
size_t GetAlignment() const override { return alignof(std::max_align_t); }
void AcceptVisitor(Detail::RFieldVisitor &visitor) const final;
void AcceptVisitor(Detail::RFieldVisitor &visitor) const override;
void GetCollectionInfo(NTupleSize_t globalIndex, RClusterIndex *collectionStart, ClusterSize_t *size) const
{
fPrincipalColumn->GetCollectionInfo(globalIndex, collectionStart, size);
Expand All @@ -885,6 +886,27 @@ public:
}
};

/// The field for a class representing an associative collection of elements via `TVirtualCollectionProxy`.
/// Besides the collection's own index column, these fields have two subfields: the key field (attached first) and the
/// value field (attached second). Each key-value pair is treated by the collection proxy as a `std::pair` object.
/// At the moment, this type of field is only enabled for `std::map` objects through `RMapField`.
class RProxiedAssociativeCollectionField : public RProxiedCollectionField {
protected:
/// Class of the collection's items as returned by the proxy's `GetValueClass()`; used to look up the offsets of the
/// `first` and `second` data members of each item.
TClass *fItemClass;

/// Takes ownership of the key and value fields and attaches them, in this order, as subfields.
RProxiedAssociativeCollectionField(std::string_view fieldName, std::string_view typeName,
std::unique_ptr<Detail::RFieldBase> keyField,
std::unique_ptr<Detail::RFieldBase> valueField);

protected:
/// Appends key and value of every item of the collection to the respective subfield, followed by the index column.
std::size_t AppendImpl(const void *from) final;
/// Reads key and value of every item of the requested entry through the respective subfield.
void ReadGlobalImpl(NTupleSize_t globalIndex, void *to) final;

public:
/// Returns the keys and values of the collection as a flat, alternating sequence: key, value, key, value, ...
std::vector<RValue> SplitValue(const RValue &value) const final;
/// Dispatches to `RFieldVisitor::VisitProxiedAssociativeCollectionField()`.
void AcceptVisitor(Detail::RFieldVisitor &visitor) const final;
};

/// The field for an untyped record. The subfields are stored consecutively in a memory block, i.e.
/// the memory layout is identical to the one that a C++ struct would have.
class RRecordField : public Detail::RFieldBase {
Expand Down Expand Up @@ -1155,6 +1177,21 @@ public:
size_t GetAlignment() const override { return std::alignment_of<std::set<std::max_align_t>>(); }
};

/// The generic field for a std::map<KeyType, ValueType>.
/// All I/O is inherited from `RProxiedAssociativeCollectionField`; this class only adds cloning and the alignment.
class RMapField : public RProxiedAssociativeCollectionField {
protected:
/// Creates a new map field named `newName` with clones of the key and value subfields.
std::unique_ptr<Detail::RFieldBase> CloneImpl(std::string_view newName) const final;

public:
/// `keyField` and `valueField` are handed over to the base class, which attaches them as subfields.
RMapField(std::string_view fieldName, std::string_view typeName, std::unique_ptr<Detail::RFieldBase> keyField,
std::unique_ptr<Detail::RFieldBase> valueField);
RMapField(RMapField &&other) = default;
RMapField &operator=(RMapField &&other) = default;
~RMapField() override = default;

/// The key and value types are not known at compile time here, so the alignment of one fixed std::map instantiation
/// is returned. NOTE(review): presumably the alignment of std::map does not depend on its template arguments -- confirm.
size_t GetAlignment() const override { return std::alignment_of<std::map<std::max_align_t, std::max_align_t>>(); }
};

/// The field for values that may or may not be present in an entry. Parent class for unique pointer field and
/// optional field. A nullable field cannot be instantiated itself but only its descendants.
/// The RNullableField takes care of the on-disk representation. Child classes are responsible for the in-memory
Expand Down Expand Up @@ -2196,6 +2233,37 @@ public:
size_t GetAlignment() const final { return std::alignment_of<ContainerT>(); }
};

/// Statically typed field for `std::map<KeyT, ValueT>`. The key and value subfields are created from the
/// corresponding `RField<KeyT>` and `RField<ValueT>` specializations and are named "_0" and "_1", respectively.
template <typename KeyT, typename ValueT>
class RField<std::map<KeyT, ValueT>> : public RMapField {
   using MapT = typename std::map<KeyT, ValueT>;

protected:
   /// Default-constructs an empty map in the memory at `where`.
   void GenerateValue(void *where) const final { new (where) MapT(); }

   /// Runs the map destructor on `objPtr` and lets the base class take care of the rest.
   void DestroyValue(void *objPtr, bool dtorOnly = false) const final
   {
      auto *typedPtr = static_cast<MapT *>(objPtr);
      typedPtr->~MapT();
      Detail::RFieldBase::DestroyValue(objPtr, dtorOnly);
   }

public:
   /// The type name is composed from the type names of the key and value fields.
   static std::string TypeName()
   {
      std::string name{"std::map<"};
      name += RField<KeyT>::TypeName();
      name += ", ";
      name += RField<ValueT>::TypeName();
      name += ">";
      return name;
   }

   explicit RField(std::string_view name)
      : RMapField(name, TypeName(), std::make_unique<RField<KeyT>>("_0"), std::make_unique<RField<ValueT>>("_1"))
   {
   }
   RField(RField &&other) = default;
   RField &operator=(RField &&other) = default;
   ~RField() override = default;

   using Detail::RFieldBase::GenerateValue;

   /// Size and alignment are those of the concrete map type.
   size_t GetValueSize() const final { return sizeof(MapT); }
   size_t GetAlignment() const final { return alignof(MapT); }
};

template <typename... ItemTs>
class RField<std::variant<ItemTs...>> : public RVariantField {
using ContainerT = typename std::variant<ItemTs...>;
Expand Down
6 changes: 6 additions & 0 deletions tree/ntuple/v7/inc/ROOT/RFieldVisitor.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public:
virtual void VisitBoolField(const RField<bool> &field) { VisitField(field); }
virtual void VisitClassField(const RClassField &field) { VisitField(field); }
virtual void VisitProxiedCollectionField(const RProxiedCollectionField &field) { VisitField(field); }
/// Default handling of associative collection fields: fall back to the generic `VisitField()`.
virtual void VisitProxiedAssociativeCollectionField(const RProxiedAssociativeCollectionField &field)
{
VisitField(field);
}
virtual void VisitRecordField(const RRecordField &field) { VisitField(field); }
virtual void VisitClusterSizeField(const RField<ClusterSize_t> &field) { VisitField(field); }
virtual void VisitCardinalityField(const RCardinalityField &field) { VisitField(field); }
Expand Down Expand Up @@ -187,6 +191,7 @@ private:
void PrintIndent();
void PrintName(const Detail::RFieldBase &field);
void PrintCollection(const Detail::RFieldBase &field);
void PrintAssociativeCollection(const Detail::RFieldBase &field);

public:
RPrintValueVisitor(Detail::RFieldBase::RValue &&value, std::ostream &output, unsigned int level = 0,
Expand Down Expand Up @@ -216,6 +221,7 @@ public:
void VisitClassField(const RClassField &field) final;
void VisitRecordField(const RRecordField &field) final;
void VisitProxiedCollectionField(const RProxiedCollectionField &field) final;
void VisitProxiedAssociativeCollectionField(const RProxiedAssociativeCollectionField &field) final;
void VisitVectorField(const RVectorField &field) final;
void VisitVectorBoolField(const RField<std::vector<bool>> &field) final;
void VisitRVecField(const RRVecField &field) final;
Expand Down
108 changes: 108 additions & 0 deletions tree/ntuple/v7/src/RField.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ std::string GetNormalizedTypeName(const std::string &typeName)
normalizedType = "std::" + normalizedType;
if (normalizedType.substr(0, 4) == "set<")
normalizedType = "std::" + normalizedType;
if (normalizedType.substr(0, 4) == "map<")
normalizedType = "std::" + normalizedType;
if (normalizedType.substr(0, 7) == "atomic<")
normalizedType = "std::" + normalizedType;

Expand Down Expand Up @@ -451,6 +453,18 @@ ROOT::Experimental::Detail::RFieldBase::Create(const std::string &fieldName, con
auto normalizedInnerTypeName = itemField->GetType();
result =
std::make_unique<RSetField>(fieldName, "std::set<" + normalizedInnerTypeName + ">", std::move(itemField));
} else if (canonicalType.substr(0, 9) == "std::map<") {
auto innerTypes = TokenizeTypeList(canonicalType.substr(9, canonicalType.length() - 10));
if (innerTypes.size() != 2)
return R__FAIL("the type list for std::map must have exactly two elements");

auto keyField = Create("_0", innerTypes[0]).Unwrap();
auto valueField = Create("_1", innerTypes[1]).Unwrap();
auto normalizedKeyTypeName = keyField->GetType();
auto normalizedValueTypeName = valueField->GetType();
result = std::make_unique<RMapField>(fieldName,
"std::map<" + normalizedKeyTypeName + ", " + normalizedValueTypeName + ">",
std::move(keyField), std::move(valueField));
} else if (canonicalType.substr(0, 12) == "std::atomic<") {
std::string itemTypeName = canonicalType.substr(12, canonicalType.length() - 13);
auto itemField = Create("_0", itemTypeName).Unwrap();
Expand Down Expand Up @@ -1685,6 +1699,81 @@ void ROOT::Experimental::RProxiedCollectionField::AcceptVisitor(Detail::RFieldVi

//------------------------------------------------------------------------------

/// Sets up the collection proxy for `typeName` through the base class, determines the class of the collection's
/// items and attaches the key and value fields as first and second subfield.
/// \param fieldName  name of the new field
/// \param typeName   name of the collection type; used to look up the `TClass` providing the collection proxy
/// \param keyField   field for the keys, attached as subfield 0
/// \param valueField field for the mapped values, attached as subfield 1
/// \throws RException if the collection proxy does not provide a class for the collection's items
ROOT::Experimental::RProxiedAssociativeCollectionField::RProxiedAssociativeCollectionField(
   std::string_view fieldName, std::string_view typeName, std::unique_ptr<Detail::RFieldBase> keyField,
   std::unique_ptr<Detail::RFieldBase> valueField)
   : RProxiedCollectionField(fieldName, typeName, TClass::GetClass(std::string(typeName).c_str()))
{
   // Per the implementation of `TGenCollectionProxy`, items are of type std::pair<key, value>.
   fItemClass = fProxy->GetValueClass();
   // Without a class for the items (e.g. no dictionary for the pair type) neither the item size nor the offsets of
   // `first` and `second` can be determined; fail with a clear error instead of dereferencing a null pointer.
   if (fItemClass == nullptr) {
      throw RException(
         R__FAIL("RField: no I/O support for the item type of associative collection " + std::string(typeName)));
   }
   fItemSize = fItemClass->GetClassSize();
   Attach(std::move(keyField));
   Attach(std::move(valueField));
}

std::size_t ROOT::Experimental::RProxiedAssociativeCollectionField::AppendImpl(const void *from)
{
std::size_t nbytes = 0;
unsigned count = 0;
TVirtualCollectionProxy::TPushPop RAII(fProxy.get(), const_cast<void *>(from));
for (auto ptr : RCollectionIterableOnce{const_cast<void *>(from), fIFuncsWrite, fProxy.get(),
(fCollectionType == kSTLvector ? fItemSize : 0U)}) {
nbytes += CallAppendOn(*fSubFields[0],
static_cast<unsigned char *>(ptr) + fItemClass->GetDataMember("first")->GetOffset());
nbytes += CallAppendOn(*fSubFields[1],
static_cast<unsigned char *>(ptr) + fItemClass->GetDataMember("second")->GetOffset());
count++;
}
fNWritten += count;
fColumns[0]->Append(&fNWritten);
return nbytes + fColumns[0]->GetElement()->GetPackedSize();
}

/// Reads the collection stored for entry `globalIndex` into the object at `to`. The number of items and the index
/// of the first item are taken from the principal column; key and value of every item are then read through the
/// respective subfield into the `first` and `second` member of the item.
void ROOT::Experimental::RProxiedAssociativeCollectionField::ReadGlobalImpl(NTupleSize_t globalIndex, void *to)
{
   ClusterSize_t nItems;
   RClusterIndex collectionStart;
   fPrincipalColumn->GetCollectionInfo(globalIndex, &collectionStart, &nItems);

   // The offsets of `first` and `second` within the item class are the same for every item, so the by-name member
   // lookup is done once up front instead of twice per item.
   const auto keyOffset = fItemClass->GetDataMember("first")->GetOffset();
   const auto valueOffset = fItemClass->GetDataMember("second")->GetOffset();

   TVirtualCollectionProxy::TPushPop RAII(fProxy.get(), to);
   void *obj =
      fProxy->Allocate(static_cast<std::uint32_t>(nItems), (fProperties & TVirtualCollectionProxy::kNeedDelete));

   unsigned i = 0;
   for (auto ptr : RCollectionIterableOnce{obj, fIFuncsRead, fProxy.get(),
                                           (fCollectionType == kSTLvector || obj != to ? fItemSize : 0U)}) {
      auto *item = static_cast<unsigned char *>(ptr);
      CallReadOn(*fSubFields[0], collectionStart + i, item + keyOffset);
      CallReadOn(*fSubFields[1], collectionStart + i, item + valueOffset);
      i++;
   }
   // Allocate() may hand back an object different from `to` (presumably a staging area); such an object has to be
   // passed to Commit() once all items are read.
   if (obj != to)
      fProxy->Commit(obj);
}

std::vector<ROOT::Experimental::Detail::RFieldBase::RValue>
ROOT::Experimental::RProxiedAssociativeCollectionField::SplitValue(const RValue &value) const
{
std::vector<RValue> result;
TVirtualCollectionProxy::TPushPop RAII(fProxy.get(), value.GetRawPtr());
for (auto ptr : RCollectionIterableOnce{value.GetRawPtr(), fIFuncsWrite, fProxy.get(),
(fCollectionType == kSTLvector ? fItemSize : 0U)}) {
result.emplace_back(
fSubFields[0]->BindValue(static_cast<unsigned char *>(ptr) + fItemClass->GetDataMember("first")->GetOffset()));
result.emplace_back(fSubFields[1]->BindValue(static_cast<unsigned char *>(ptr) +
fItemClass->GetDataMember("second")->GetOffset()));
}
return result;
}

/// Dispatches to the visitor's handler for associative collection fields.
void ROOT::Experimental::RProxiedAssociativeCollectionField::AcceptVisitor(Detail::RFieldVisitor &visitor) const
{
visitor.VisitProxiedAssociativeCollectionField(*this);
}

//------------------------------------------------------------------------------

ROOT::Experimental::RRecordField::RRecordField(std::string_view fieldName,
std::vector<std::unique_ptr<Detail::RFieldBase>> &&itemFields,
const std::vector<std::size_t> &offsets, std::string_view typeName)
Expand Down Expand Up @@ -2619,6 +2708,25 @@ ROOT::Experimental::RSetField::CloneImpl(std::string_view newName) const

//------------------------------------------------------------------------------

/// Forwards all arguments to `RProxiedAssociativeCollectionField`; ownership of the key and value fields is passed on.
ROOT::Experimental::RMapField::RMapField(std::string_view fieldName, std::string_view typeName,
std::unique_ptr<Detail::RFieldBase> keyField,
std::unique_ptr<Detail::RFieldBase> valueField)
: ROOT::Experimental::RProxiedAssociativeCollectionField(fieldName, typeName, std::move(keyField),
std::move(valueField))
{
}

std::unique_ptr<ROOT::Experimental::Detail::RFieldBase>
ROOT::Experimental::RMapField::CloneImpl(std::string_view newName) const
{
auto newKeyField = fSubFields[0]->Clone(fSubFields[0]->GetName());
auto newValueField = fSubFields[1]->Clone(fSubFields[1]->GetName());
return std::unique_ptr<RMapField>(
new RMapField(newName, GetType(), std::move(newKeyField), std::move(newValueField)));
}

//------------------------------------------------------------------------------

ROOT::Experimental::RNullableField::RNullableField(std::string_view fieldName, std::string_view typeName,
std::unique_ptr<Detail::RFieldBase> itemField)
: ROOT::Experimental::Detail::RFieldBase(fieldName, typeName, ENTupleStructure::kCollection, false /* isSimple */)
Expand Down
31 changes: 31 additions & 0 deletions tree/ntuple/v7/src/RFieldVisitor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,31 @@ void ROOT::Experimental::RPrintValueVisitor::PrintCollection(const Detail::RFiel
fOutput << "]";
}

/// Prints an associative collection as a list of single-entry objects: [{key: value}, {key: value}, ...].
/// The split value of `field` is consumed two elements at a time, the first being printed as key and the second as
/// mapped value.
void ROOT::Experimental::RPrintValueVisitor::PrintAssociativeCollection(const Detail::RFieldBase &field)
{
   PrintIndent();
   PrintName(field);
   fOutput << "[";

   // Keys and values are printed on a single line and without their field names
   RPrintOptions itemOptions;
   itemOptions.fPrintSingleLine = true;
   itemOptions.fPrintName = false;

   auto items = field.SplitValue(fValue);
   bool isFirstPair = true;
   for (auto it = items.begin(); it != items.end(); ++it) {
      if (!isFirstPair)
         fOutput << ", ";
      isFirstPair = false;

      fOutput << "{";
      RPrintValueVisitor keyVisitor(it->GetNonOwningCopy(), fOutput, 0 /* level */, itemOptions);
      it->GetField()->AcceptVisitor(keyVisitor);
      fOutput << ": ";

      // Step from the key to the value that belongs to it
      ++it;
      RPrintValueVisitor valueVisitor(it->GetNonOwningCopy(), fOutput, 0 /* level */, itemOptions);
      it->GetField()->AcceptVisitor(valueVisitor);
      fOutput << "}";
   }
   fOutput << "]";
}

void ROOT::Experimental::RPrintValueVisitor::VisitField(const Detail::RFieldBase &field)
{
Expand Down Expand Up @@ -397,6 +422,12 @@ void ROOT::Experimental::RPrintValueVisitor::VisitProxiedCollectionField(const R
PrintCollection(field);
}

/// Associative collections are printed as a list of {key: value} objects.
void ROOT::Experimental::RPrintValueVisitor::VisitProxiedAssociativeCollectionField(
const RProxiedAssociativeCollectionField &field)
{
PrintAssociativeCollection(field);
}

void ROOT::Experimental::RPrintValueVisitor::VisitVectorField(const RVectorField &field)
{
PrintCollection(field);
Expand Down
8 changes: 8 additions & 0 deletions tree/ntuple/v7/test/CustomStructLinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@
#pragma link C++ class std::set<std::set<char>> +;
#pragma link C++ class std::set<std::pair<int, CustomStruct>> +;

#pragma link C++ class std::map<char, long> + ;
#pragma link C++ class std::map<char, std::int64_t> + ;
#pragma link C++ class std::map<char, std::string> +;
#pragma link C++ class std::map<int, std::vector<CustomStruct>> + ;
#pragma link C++ class std::map<std::string, float> + ;
#pragma link C++ class std::map<char, std::map<int, CustomStruct>> +;
#pragma link C++ class std::map<float, std::map<char, std::int32_t>> +;

#pragma link C++ options = version(3) class StructWithIORulesBase + ;
#pragma link C++ options = version(3) class StructWithTransientString + ;
#pragma link C++ options = version(3) class StructWithIORules + ;
Expand Down
28 changes: 28 additions & 0 deletions tree/ntuple/v7/test/ntuple_show.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,34 @@ TEST(RNTupleShow, CollectionProxy)
}
}

// Writes a single entry with a std::map<std::string, float> field and checks the output of RNTupleReader::Show().
TEST(RNTupleShow, AssociativeCollectionProxy)
{
FileRaii fileGuard("test_ntuple_show_assoccollectionproxy.ntuple");
{
// Write one entry; the writer goes out of scope at the end of this block, before the file is read back
auto model = RNTupleModel::Create();
auto mapF = model->MakeField<std::map<std::string, float>>("mapF");
auto ntuple = RNTupleWriter::Recreate(std::move(model), "f", fileGuard.GetPath());

*mapF = {{"foo", 3.14}, {"bar", 2.72}};
ntuple->Fill();
}

{
auto ntuple = RNTupleReader::Open("f", fileGuard.GetPath());
EXPECT_EQ(1U, ntuple->GetNEntries());

std::ostringstream os;
ntuple->Show(0, os);
// std::map iterates in key order, hence "bar" is expected before "foo" although it was inserted second
// clang-format off
std::string expected{std::string("")
+ "{\n"
+ " \"mapF\": [{\"bar\": 2.72}, {\"foo\": 3.14}]\n"
+ "}\n"};
// clang-format on
EXPECT_EQ(os.str(), expected);
}
}

TEST(RNTupleShow, Enum)
{

Expand Down
Loading

0 comments on commit 972a45e

Please sign in to comment.