-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Parallel transport generic vector interpolation method. #163
Changes from 13 commits
2c9512f
7c0a5d7
d01f5dd
61ec2e7
d5c9ef5
088403e
a6d0df9
5064f44
d0da681
91de6bf
6b43f84
35611a8
a913b18
c4481aa
6107e91
77856f3
346f7f1
ee4c3ad
a96057e
8b3e8f4
d63503f
636c6d8
3d16b6b
ba8193c
fd77d76
3686d84
09af7e7
1212789
98ab54a
96b28b2
56badc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
/* | ||
* (C) Crown Copyright 2023 Met Office | ||
* | ||
* This software is licensed under the terms of the Apache Licence Version 2.0 | ||
* which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
*/ | ||
|
||
#include "atlas/library/defines.h" | ||
#if ATLAS_HAVE_EIGEN | ||
|
||
#include <cmath> | ||
#include <tuple> | ||
|
||
#include "atlas/array/ArrayView.h" | ||
#include "atlas/array/helpers/ArrayForEach.h" | ||
#include "atlas/array/Range.h" | ||
#include "atlas/field/Field.h" | ||
#include "atlas/field/FieldSet.h" | ||
#include "atlas/interpolation/Cache.h" | ||
#include "atlas/interpolation/Interpolation.h" | ||
#include "atlas/interpolation/method/MethodFactory.h" | ||
#include "atlas/interpolation/method/sphericalvector/SphericalVector.h" | ||
odlomax marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include "atlas/linalg/sparse.h" | ||
#include "atlas/option/Options.h" | ||
odlomax marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include "atlas/parallel/omp/omp.h" | ||
#include "atlas/runtime/Exception.h" | ||
#include "atlas/runtime/Trace.h" | ||
#include "atlas/util/Constants.h" | ||
#include "atlas/util/UnitSphere.h" | ||
|
||
#include "eckit/linalg/Triplet.h" | ||
odlomax marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
namespace atlas { | ||
namespace interpolation { | ||
namespace method { | ||
|
||
using Complex = SphericalVector::Complex; | ||
|
||
template <typename Value> | ||
using SparseMatrix = SphericalVector::SparseMatrix<Value>; | ||
using RealMatrixMap = Eigen::Map<const SparseMatrix<double>>; | ||
using ComplexTriplets = std::vector<Eigen::Triplet<Complex>>; | ||
using EckitMatrix = eckit::linalg::SparseMatrix; | ||
|
||
namespace { | ||
|
||
MethodBuilder<SphericalVector> __builder("spherical-vector"); | ||
|
||
RealMatrixMap makeMatrixMap(const EckitMatrix& baseMatrix) { | ||
return RealMatrixMap(baseMatrix.rows(), baseMatrix.cols(), | ||
baseMatrix.nonZeros(), baseMatrix.outer(), | ||
baseMatrix.inner(), baseMatrix.data()); | ||
} | ||
|
||
template <typename MatrixT, typename Functor> | ||
void sparseMatrixForEach(const MatrixT& matrix, const Functor& functor) { | ||
|
||
atlas_omp_parallel_for (auto k = 0; k < matrix.outerSize(); ++k) { | ||
for (auto it = typename MatrixT::InnerIterator(matrix, k); it; ++it) { | ||
functor(it.value(), it.row(), it.col()); | ||
} | ||
} | ||
} | ||
|
||
template <typename MatrixT, typename SourceView, typename TargetView, | ||
typename Functor> | ||
void matrixMultiply(const MatrixT& matrix, SourceView&& sourceView, | ||
TargetView&& targetView, const Functor& mappingFunctor) { | ||
|
||
sparseMatrixForEach(matrix, [&](const auto& weight, auto i, auto j) { | ||
|
||
constexpr auto rank = std::decay_t<SourceView>::rank(); | ||
if constexpr(rank == 2) { | ||
const auto sourceSlice = sourceView.slice(j, array::Range::all()); | ||
auto targetSlice = targetView.slice(i, array::Range::all()); | ||
mappingFunctor(weight, sourceSlice, targetSlice); | ||
} | ||
else if constexpr(rank == 3) { | ||
const auto iterationFuctor = [&](auto&& sourceVars, auto&& targetVars) { | ||
mappingFunctor(weight, sourceVars, targetVars); | ||
}; | ||
const auto sourceSlice = | ||
sourceView.slice(j, array::Range::all(), array::Range::all()); | ||
auto targetSlice = | ||
targetView.slice(i, array::Range::all(), array::Range::all()); | ||
array::helpers::ArrayForEach<0>::apply( | ||
std::tie(sourceSlice, targetSlice), iterationFuctor); | ||
} | ||
else { | ||
ATLAS_NOTIMPLEMENTED; | ||
} | ||
}); | ||
} | ||
|
||
} // namespace | ||
|
||
void SphericalVector::do_setup(const Grid& source, const Grid& target, | ||
const Cache&) { | ||
ATLAS_NOTIMPLEMENTED; | ||
} | ||
|
||
void SphericalVector::do_setup(const FunctionSpace& source, | ||
const FunctionSpace& target) { | ||
ATLAS_TRACE("interpolation::method::SphericalVector::do_setup"); | ||
source_ = source; | ||
target_ = target; | ||
|
||
if (target_.size() == 0) { | ||
return; | ||
} | ||
|
||
const auto baseInterpolator = | ||
Interpolation(interpolationScheme_, source_, target_); | ||
setMatrix(MatrixCache(baseInterpolator)); | ||
|
||
// Get matrix data. | ||
const auto nRows = matrix().rows(); | ||
const auto nCols = matrix().cols(); | ||
const auto nNonZeros = matrix().nonZeros(); | ||
const auto realWeights = makeMatrixMap(matrix()); | ||
|
||
complexWeights_ = std::make_shared<ComplexMatrix>(nRows, nCols); | ||
auto complexTriplets = ComplexTriplets(nNonZeros); | ||
|
||
const auto sourceLonLats = array::make_view<double, 2>(source_.lonlat()); | ||
const auto targetLonLats = array::make_view<double, 2>(target_.lonlat()); | ||
|
||
sparseMatrixForEach(realWeights, [&](const auto& weight, auto i, auto j) { | ||
|
||
const auto sourceLonLat = | ||
PointLonLat(sourceLonLats(j, 0), sourceLonLats(j, 1)); | ||
const auto targetLonLat = | ||
PointLonLat(targetLonLats(i, 0), targetLonLats(i, 1)); | ||
|
||
const auto alpha = util::greatCircleCourse(sourceLonLat, targetLonLat); | ||
|
||
const auto deltaAlpha = | ||
(alpha.first - alpha.second) * util::Constants::degreesToRadians(); | ||
|
||
const auto idx = std::distance(realWeights.valuePtr(), &weight); | ||
|
||
complexTriplets[idx] = {int(i), int(j), std::polar(weight, deltaAlpha)}; | ||
}); | ||
complexWeights_->setFromTriplets(complexTriplets.begin(), | ||
complexTriplets.end()); | ||
|
||
ATLAS_ASSERT(complexWeights_->nonZeros() == matrix().nonZeros()); | ||
} | ||
|
||
void SphericalVector::print(std::ostream&) const { ATLAS_NOTIMPLEMENTED; } | ||
|
||
void SphericalVector::do_execute(const FieldSet& sourceFieldSet, | ||
FieldSet& targetFieldSet, | ||
Metadata& metadata) const { | ||
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); | ||
|
||
const auto nFields = sourceFieldSet.size(); | ||
ATLAS_ASSERT(nFields == targetFieldSet.size()); | ||
|
||
for (auto i = 0; i < sourceFieldSet.size(); ++i) { | ||
do_execute(sourceFieldSet[i], targetFieldSet[i], metadata); | ||
} | ||
} | ||
|
||
void SphericalVector::do_execute(const Field& sourceField, Field& targetField, | ||
Metadata&) const { | ||
ATLAS_TRACE("atlas::interpolation::method::SphericalVector::do_execute()"); | ||
|
||
const auto fieldType = sourceField.metadata().getString("type", ""); | ||
if (fieldType != "vector") { | ||
|
||
auto metadata = Metadata(); | ||
Method::do_execute(sourceField, targetField, metadata); | ||
|
||
return; | ||
} | ||
|
||
if (target_.size() == 0) { | ||
return; | ||
} | ||
|
||
ATLAS_ASSERT_MSG(sourceField.variables() == 2 || sourceField.variables() == 3, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this assertion come right at the beggining of the method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd also have to check that the field was a vector type. By this point in the function, it definitely is. |
||
"Vector field can only have 2 or 3 components."); | ||
|
||
Method::check_compatibility(sourceField, targetField, matrix()); | ||
|
||
haloExchange(sourceField); | ||
|
||
if (sourceField.datatype().kind() == array::DataType::KIND_REAL64) { | ||
interpolate_vector_field<double>(sourceField, targetField); | ||
} else if (sourceField.datatype().kind() == array::DataType::KIND_REAL32) { | ||
interpolate_vector_field<float>(sourceField, targetField); | ||
} else { | ||
ATLAS_NOTIMPLEMENTED; | ||
} | ||
|
||
targetField.set_dirty(); | ||
} | ||
|
||
template <typename Value> | ||
void SphericalVector::interpolate_vector_field(const Field& sourceField, | ||
Field& targetField) const { | ||
if (sourceField.rank() == 2) { | ||
interpolate_vector_field<Value, 2>(sourceField, targetField); | ||
} else if (sourceField.rank() == 3) { | ||
interpolate_vector_field<Value, 3>(sourceField, targetField); | ||
} else { | ||
ATLAS_NOTIMPLEMENTED; | ||
} | ||
} | ||
|
||
template <typename Value, int Rank> | ||
void SphericalVector::interpolate_vector_field(const Field& sourceField, | ||
Field& targetField) const { | ||
|
||
const auto sourceView = array::make_view<Value, Rank>(sourceField); | ||
auto targetView = array::make_view<Value, Rank>(targetField); | ||
targetView.assign(0.); | ||
|
||
const auto horizontalComponent = [](const auto& weight, auto&& sourceVars, | ||
auto&& targetVars) { | ||
const auto sourceVector = Complex(sourceVars(0), sourceVars(1)); | ||
const auto targetVector = weight * sourceVector; | ||
targetVars(0) += targetVector.real(); | ||
targetVars(1) += targetVector.imag(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Without going into the detailed maths, why do you have to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Never mind, I understand, due to the matrix multiply internally. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's because each target element (named There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could do the I suspect details like this will have to be refactored to remove code duplication when the adjoint methods are added. In that context, you don't assign zero to the array you're writing to. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK it can stay as-is. We should not attempt premature optimisation now. |
||
}; | ||
|
||
matrixMultiply(*complexWeights_, sourceView, targetView, horizontalComponent); | ||
|
||
if (sourceField.variables() == 2) return; | ||
|
||
const auto verticalComponent = []( | ||
const auto& weight, auto&& sourceVars, | ||
auto&& targetVars) { targetVars(2) += weight * sourceVars(2); }; | ||
|
||
const auto realWeights = makeMatrixMap(matrix()); | ||
matrixMultiply(realWeights, sourceView, targetView, verticalComponent); | ||
} | ||
|
||
} // namespace method | ||
} // namespace interpolation | ||
} // namespace atlas | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was still good to have it in this file, but not guarded by
#if ATLAS_HAVE_EIGEN
if that makes sense.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can see that it adds unnecessary marcro mess. I'm just marvelling at how the factories can register at run time instead of compile time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I meant is what I added to your branch with a96057e
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah! That looks far safer!