Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel transport generic vector interpolation method. #163

Merged
merged 31 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2c9512f
Implemented forward interpolation. TODO: adjoint and more tests.
odlomax Nov 17, 2023
7c0a5d7
Fixed race condition errors.
odlomax Nov 20, 2023
d01f5dd
Replaced rotation matrices with complex interpolation weights.
odlomax Nov 21, 2023
61ec2e7
Merge branch 'develop' into feature/parallel_transport
odlomax Nov 21, 2023
d5c9ef5
Fixed typos before discussion.
odlomax Nov 23, 2023
088403e
Renamed class to SphericalVector. Added Eigen3 matrices.
odlomax Nov 28, 2023
a6d0df9
Update SphericalVector.cc
odlomax Nov 28, 2023
5064f44
Tidied up SphericalVector class.
odlomax Nov 29, 2023
d0da681
Finalising for PR.
odlomax Nov 29, 2023
91de6bf
Merge branch 'develop' into feature/parallel_transport
odlomax Nov 29, 2023
6b43f84
Minor cosmetic changes.
odlomax Nov 30, 2023
35611a8
Replaced optional compilation with #if ATLAS_HAVE_EIGEN in source fil…
odlomax Nov 30, 2023
a913b18
Removed redundant macros.
odlomax Nov 30, 2023
c4481aa
Removed static factory linking.
odlomax Nov 30, 2023
6107e91
Fused horizontal and vertical component matrix-multiplications. TODO:…
odlomax Dec 1, 2023
77856f3
Tidied fused loop.
odlomax Dec 1, 2023
346f7f1
Uncovered and fixed differences in eckit and Eigen3 CRS format. Also
odlomax Dec 5, 2023
ee4c3ad
Added multiple levels to 3d fields.
odlomax Dec 5, 2023
a96057e
Add SphericalVector to MethodFactory
wdeconinck Dec 7, 2023
8b3e8f4
Added more consistent types to iteration indices.
odlomax Dec 7, 2023
d63503f
Further index consistency added.
odlomax Dec 7, 2023
636c6d8
Removed superfluous templates.
odlomax Dec 8, 2023
3d16b6b
Tided up macros.
odlomax Dec 8, 2023
ba8193c
Merge branch 'develop' into feature/parallel_transport
odlomax Dec 11, 2023
fd77d76
Disable OpenMP for older intel-classic compiler (< intel/2022.2)
wdeconinck Dec 12, 2023
3686d84
Enable test with CONDITION statement
wdeconinck Dec 12, 2023
09af7e7
Revert whitespace changes
wdeconinck Dec 12, 2023
1212789
Make greatCircleCourse private before moving to eckit
wdeconinck Dec 12, 2023
98ab54a
Fix header includes
wdeconinck Dec 13, 2023
96b28b2
Addressed reviewer comments.
odlomax Dec 14, 2023
56badc7
Merge branch 'feature/parallel_transport' of https://github.com/JCSDA…
odlomax Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/atlas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ functionspace/detail/PointCloudInterface.cc
functionspace/detail/CubedSphereStructure.h
functionspace/detail/CubedSphereStructure.cc

# for cubedsphere matching mesh partitioner
# for cubedsphere matching mesh partitioner
interpolation/method/cubedsphere/CellFinder.cc
interpolation/method/cubedsphere/CellFinder.h
interpolation/Vector2D.cc
Expand All @@ -539,8 +539,8 @@ interpolation/element/Triag3D.cc
interpolation/element/Triag3D.h
interpolation/method/Intersect.cc
interpolation/method/Intersect.h
interpolation/method/Ray.cc # For testing Quad
interpolation/method/Ray.h # For testing Quad
interpolation/method/Ray.cc # For testing Quad
interpolation/method/Ray.h # For testing Quad

# for BuildConvexHull3D

Expand Down Expand Up @@ -632,6 +632,8 @@ interpolation/method/knn/KNearestNeighboursBase.cc
interpolation/method/knn/KNearestNeighboursBase.h
interpolation/method/knn/NearestNeighbour.cc
interpolation/method/knn/NearestNeighbour.h
interpolation/method/sphericalvector/SphericalVector.h
interpolation/method/sphericalvector/SphericalVector.cc
interpolation/method/structured/Cubic2D.cc
interpolation/method/structured/Cubic2D.h
interpolation/method/structured/Cubic3D.cc
Expand Down Expand Up @@ -864,7 +866,7 @@ if( NOT atlas_HAVE_ATLAS_FUNCTIONSPACE )
unset( atlas_parallel_srcs )
unset( atlas_output_srcs )
unset( atlas_redistribution_srcs )
unset( atlas_linalg_srcs ) # only depends on array
unset( atlas_linalg_srcs ) # only depends on array
endif()

if( NOT atlas_HAVE_ATLAS_INTERPOLATION )
Expand Down
4 changes: 4 additions & 0 deletions src/atlas/functionspace/NodeColumns.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ void NodeColumns::set_field_metadata(const eckit::Configuration& config, Field&
idx_t variables(0);
config.get("variables", variables);
field.set_variables(variables);

if (config.has("type")) {
field.metadata().set("type", config.getString("type"));
}
}

array::DataType NodeColumns::config_datatype(const eckit::Configuration& config) const {
Expand Down
2 changes: 1 addition & 1 deletion src/atlas/interpolation/method/Method.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ void Method::do_execute(const FieldSet& fieldsSource, FieldSet& fieldsTarget, Me
ATLAS_ASSERT(N == fieldsTarget.size());

for (idx_t i = 0; i < fieldsSource.size(); ++i) {
Method::do_execute(fieldsSource[i], fieldsTarget[i], metadata);
Method::do_execute(fieldsSource[i], fieldsTarget[i], metadata);
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/atlas/interpolation/method/Method.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ class Method : public util::Object {
virtual void do_setup(const FunctionSpace& source, const Field& target);
virtual void do_setup(const FunctionSpace& source, const FieldSet& target);

void check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const;

private:
template <typename Value>
void interpolate_field(const Field& src, Field& tgt, const Matrix&) const;
Expand All @@ -152,8 +154,6 @@ class Method : public util::Object {
template <typename Value>
void adjoint_interpolate_field_rank3(Field& src, const Field& tgt, const Matrix&) const;

void check_compatibility(const Field& src, const Field& tgt, const Matrix& W) const;

private:
const Matrix* matrix_ = nullptr;
std::shared_ptr<Matrix> matrix_shared_;
Expand Down
6 changes: 5 additions & 1 deletion src/atlas/interpolation/method/MethodFactory.cc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was still good to have it in this file, but not guarded by #if ATLAS_HAVE_EIGEN if that makes sense.

Copy link
Contributor Author

@odlomax odlomax Nov 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see that it adds unnecessary marcro mess. I'm just marvelling at how the factories can register at run time instead of compile time!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant is what I added to your branch with a96057e

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! That looks far safer!

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* nor does it submit to any jurisdiction.
*/

#include "atlas/library/defines.h"
#include "MethodFactory.h"

// for static linking
Expand All @@ -16,6 +17,7 @@
#include "knn/GridBoxMaximum.h"
#include "knn/KNearestNeighbours.h"
#include "knn/NearestNeighbour.h"
#include "sphericalvector/SphericalVector.h"
#include "structured/Cubic2D.h"
#include "structured/Cubic3D.h"
#include "structured/Linear2D.h"
Expand All @@ -25,7 +27,6 @@
#include "unstructured/FiniteElement.h"
#include "unstructured/UnstructuredBilinearLonLat.h"


namespace atlas {
namespace interpolation {

Expand All @@ -47,6 +48,9 @@ void force_link() {
MethodBuilder<method::GridBoxAverage>();
MethodBuilder<method::GridBoxMaximum>();
MethodBuilder<method::CubedSphereBilinear>();
#if ATLAS_HAVE_EIGEN
odlomax marked this conversation as resolved.
Show resolved Hide resolved
MethodBuilder<method::SphericalVector>();
#endif
}
} link;
}
Expand Down
244 changes: 244 additions & 0 deletions src/atlas/interpolation/method/sphericalvector/SphericalVector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* (C) Crown Copyright 2023 Met Office
*
* This software is licensed under the terms of the Apache Licence Version 2.0
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
*/

#include "atlas/library/defines.h"
#if ATLAS_HAVE_EIGEN

#include <cmath>
#include <tuple>

#include "atlas/array/ArrayView.h"
#include "atlas/array/helpers/ArrayForEach.h"
#include "atlas/array/Range.h"
#include "atlas/field/Field.h"
#include "atlas/field/FieldSet.h"
#include "atlas/interpolation/Cache.h"
#include "atlas/interpolation/Interpolation.h"
#include "atlas/interpolation/method/MethodFactory.h"
#include "atlas/interpolation/method/sphericalvector/SphericalVector.h"
odlomax marked this conversation as resolved.
Show resolved Hide resolved
#include "atlas/linalg/sparse.h"
#include "atlas/option/Options.h"
odlomax marked this conversation as resolved.
Show resolved Hide resolved
#include "atlas/parallel/omp/omp.h"
#include "atlas/runtime/Exception.h"
#include "atlas/runtime/Trace.h"
#include "atlas/util/Constants.h"
#include "atlas/util/UnitSphere.h"

#include "eckit/linalg/Triplet.h"
odlomax marked this conversation as resolved.
Show resolved Hide resolved

namespace atlas {
namespace interpolation {
namespace method {

using Complex = SphericalVector::Complex;

template <typename Value>
using SparseMatrix = SphericalVector::SparseMatrix<Value>;
using RealMatrixMap = Eigen::Map<const SparseMatrix<double>>;
using ComplexTriplets = std::vector<Eigen::Triplet<Complex>>;
using EckitMatrix = eckit::linalg::SparseMatrix;

namespace {

MethodBuilder<SphericalVector> __builder("spherical-vector");

RealMatrixMap makeMatrixMap(const EckitMatrix& baseMatrix) {
return RealMatrixMap(baseMatrix.rows(), baseMatrix.cols(),
baseMatrix.nonZeros(), baseMatrix.outer(),
baseMatrix.inner(), baseMatrix.data());
}

template <typename MatrixT, typename Functor>
void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) {

atlas_omp_parallel_for (auto k = 0; k < matrix.outerSize(); ++k) {
for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) {
functor(it.value(), it.row(), it.col());
}
}
}

template <typename MatrixT, typename SourceView, typename TargetView,
typename Functor>
void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView,
TargetView&& targetView, const Functor& mappingFunctor) {

sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) {

constexpr auto rank = std::decay_t<SourceView>::rank();
if constexpr(rank == 2) {
const auto sourceSlice = sourceView.slice(j, array::Range::all());
auto targetSlice = targetView.slice(i, array::Range::all());
mappingFunctor(weight, sourceSlice, targetSlice);
}
else if constexpr(rank == 3) {
const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) {
mappingFunctor(weight, sourceVars, targetVars);
};
const auto sourceSlice =
sourceView.slice(j, array::Range::all(), array::Range::all());
auto targetSlice =
targetView.slice(i, array::Range::all(), array::Range::all());
array::helpers::ArrayForEach<0>::apply(
std::tie(sourceSlice, targetSlice), iterationFuctor);
}
else {
ATLAS_NOTIMPLEMENTED;
}
});
}

} // namespace

void SphericalVector::do_setup(const Grid& source, const Grid& target,
const Cache&) {
ATLAS_NOTIMPLEMENTED;
}

void SphericalVector::do_setup(const FunctionSpace& source,
const FunctionSpace& target) {
ATLAS_TRACE("interpolation::method::SphericalVector::do_setup");
source_ = source;
target_ = target;

if (target_.size() == 0) {
return;
}

const auto baseInterpolator =
Interpolation(interpolationScheme_, source_, target_);
setMatrix(MatrixCache(baseInterpolator));

// Get matrix data.
const auto nRows = matrix().rows();
const auto nCols = matrix().cols();
const auto nNonZeros = matrix().nonZeros();
const auto realWeights = makeMatrixMap(matrix());

complexWeights_ = std::make_shared<ComplexMatrix>(nRows, nCols);
auto complexTriplets = ComplexTriplets(nNonZeros);

const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat());
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat());

sparseMatrixForEach(realWeights, [&](const auto& weight, auto i, auto j) {

const auto sourceLonLat =
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1));
const auto targetLonLat =
PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1));

const auto alpha = util::greatCircleCourse(sourceLonLat, targetLonLat);

const auto deltaAlpha =
(alpha.first - alpha.second) * util::Constants::degreesToRadians();

const auto idx = std::distance(realWeights.valuePtr(), &weight);

complexTriplets[idx] = {int(i), int(j), std::polar(weight, deltaAlpha)};
});
complexWeights_->setFromTriplets(complexTriplets.begin(),
complexTriplets.end());

ATLAS_ASSERT(complexWeights_->nonZeros() == matrix().nonZeros());
}

void SphericalVector::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; }

void SphericalVector::do_execute(const FieldSet& sourceFieldSet,
FieldSet& targetFieldSet,
Metadata& metadata) const {
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");

const auto nFields = sourceFieldSet.size();
ATLAS_ASSERT(nFields == targetFieldSet.size());

for (auto i = 0; i < sourceFieldSet.size(); ++i) {
do_execute(sourceFieldSet[i], targetFieldSet[i], metadata);
}
}

void SphericalVector::do_execute(const Field& sourceField, Field& targetField,
Metadata&) const {
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()");

const auto fieldType = sourceField.metadata().getString("type", "");
if (fieldType != "vector") {

auto metadata = Metadata();
Method::do_execute(sourceField, targetField, metadata);

return;
}

if (target_.size() == 0) {
return;
}

ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this assertion come right at the beggining of the method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd also have to check that the field was a vector type. By this point in the function, it definitely is.

"Vector field can only have 2 or 3 components.");

Method::check_compatibility(sourceField, targetField, matrix());

haloExchange(sourceField);

if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) {
interpolate_vector_field<double>(sourceField, targetField);
} else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) {
interpolate_vector_field<float>(sourceField, targetField);
} else {
ATLAS_NOTIMPLEMENTED;
}

targetField.set_dirty();
}

template <typename Value>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {
if (sourceField.rank() == 2) {
interpolate_vector_field<Value, 2>(sourceField, targetField);
} else if (sourceField.rank() == 3) {
interpolate_vector_field<Value, 3>(sourceField, targetField);
} else {
ATLAS_NOTIMPLEMENTED;
}
}

template <typename Value, int Rank>
void SphericalVector::interpolate_vector_field(const Field& sourceField,
Field& targetField) const {

const auto sourceView = array::make_view<Value, Rank>(sourceField);
auto targetView = array::make_view<Value, Rank>(targetField);
targetView.assign(0.);

const auto horizontalComponent = [](const auto& weight, auto&& sourceVars,
auto&& targetVars) {
const auto sourceVector = Complex(sourceVars(0), sourceVars(1));
const auto targetVector = weight * sourceVector;
targetVars(0) += targetVector.real();
targetVars(1) += targetVector.imag();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without going into the detailed maths, why do you have to
targetView.assign(0)
and then use += for the targetVars ? Can this not be replaced with just = ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I understand, due to the matrix multiply internally.
But perhaps performance can be gained if the assign(0) could be removed and done inside the matrixMultiply

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because each target element (named targetVars here) has a contribution from multiple source elements (sourceVars). If I changed += to =, it would set the target element to the last source element on that row of the matrix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could do the assign(0.) within matrixMultiply function, but it would have to be on the first line, and therefore be the equivalent operation as it currently is. The SparseMatrixForEach is the only function that has knowledge of i as a whole row, but it doesn't know what functor is doing (it's used in a couple of different contexts).

I suspect details like this will have to be refactored to remove code duplication when the adjoint methods are added. In that context, you don't assign zero to the array you're writing to.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK it can stay as-is. We should not attempt premature optimisation now.

};

matrixMultiply(*complexWeights_, sourceView, targetView, horizontalComponent);

if (sourceField.variables() == 2) return;

const auto verticalComponent = [](
const auto& weight, auto&& sourceVars,
auto&& targetVars) { targetVars(2) += weight * sourceVars(2); };

const auto realWeights = makeMatrixMap(matrix());
matrixMultiply(realWeights, sourceView, targetView, verticalComponent);
}

} // namespace method
} // namespace interpolation
} // namespace atlas

#endif
Loading