-
Notifications
You must be signed in to change notification settings - Fork 847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hybrid Parallel AD (Part 1/?) #1214
Changes from 65 commits
5fd72ca
679e979
b4650ba
caa1542
d153a00
d9ce155
c9ac197
5074ee3
33437ce
5735c0e
4a820f7
a26e2be
94ac52e
ce4a3bc
7bbb9cd
cfb7285
8fc0941
e04f931
1351c79
6bf97a2
aeaf251
5cea386
a440085
223c10d
6775b29
6aaebca
ce44cac
ef7ad26
f093b35
e174bac
2182622
f71b9ec
94dafb4
8eb3094
66d51df
a7fbcd6
2776775
74f20c4
63003ee
b329cb6
02c9c8e
ac5c581
3527c28
b7d3a8e
9efa995
83b032b
7465871
ecb64d0
8b4a89c
165a52b
f63286b
60792dc
3b0854b
67bdd31
4738e29
7e0bc67
3e82662
c82f3c7
e5e3ebc
6483a3f
083f0b7
c3a62d3
92406ed
73a575b
3870382
45cc9a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
/*! | ||
* \file ad_structure.hpp | ||
* \brief Main routines for the algorithmic differentiation (AD) structure. | ||
* \author T. Albring | ||
* \author T. Albring, J. Blühdorn | ||
* \version 7.1.1 "Blackbird" | ||
* | ||
* SU2 Project Website: https://su2code.github.io | ||
|
@@ -27,7 +27,8 @@ | |
|
||
#pragma once | ||
|
||
#include "datatype_structure.hpp" | ||
#include "../code_config.hpp" | ||
#include "../parallelization/omp_structure.hpp" | ||
|
||
/*! | ||
* \namespace AD | ||
|
@@ -278,62 +279,92 @@ namespace AD{ | |
|
||
extern int adjointVectorPosition; | ||
|
||
/*--- Reference to the tape ---*/ | ||
|
||
extern su2double::TapeType& globalTape; | ||
|
||
extern bool Status; | ||
|
||
extern bool PreaccActive; | ||
|
||
extern bool PreaccEnabled; | ||
|
||
extern su2double::TapeType::Position StartPosition, EndPosition; | ||
#ifdef HAVE_OPDI | ||
using CoDiTapePosition = su2double::TapeType::Position; | ||
using OpDiState = void*; | ||
using TapePosition = std::pair<CoDiTapePosition, OpDiState>; | ||
#else | ||
using TapePosition = su2double::TapeType::Position; | ||
#endif | ||
|
||
extern TapePosition StartPosition, EndPosition; | ||
|
||
extern std::vector<su2double::TapeType::Position> TapePositions; | ||
extern std::vector<TapePosition> TapePositions; | ||
|
||
extern std::vector<su2double::GradientData> localInputValues; | ||
|
||
extern std::vector<su2double*> localOutputValues; | ||
|
||
extern codi::PreaccumulationHelper<su2double> PreaccHelper; | ||
|
||
/*--- Reference to the tape. ---*/ | ||
|
||
FORCEINLINE su2double::TapeType& getGlobalTape() { | ||
return su2double::getGlobalTape(); | ||
} | ||
|
||
Comment on lines
+306
to
+311
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tapes may change during runtime, and different threads use different tapes. Hence, each reference to the tape must be resolved dynamically via this call. |
||
FORCEINLINE void RegisterInput(su2double &data, bool push_index = true) { | ||
AD::globalTape.registerInput(data); | ||
AD::getGlobalTape().registerInput(data); | ||
if (push_index) { | ||
inputValues.push_back(data.getGradientData()); | ||
} | ||
} | ||
|
||
FORCEINLINE void RegisterOutput(su2double& data) {AD::globalTape.registerOutput(data);} | ||
FORCEINLINE void RegisterOutput(su2double& data) {AD::getGlobalTape().registerOutput(data);} | ||
|
||
FORCEINLINE void ResetInput(su2double &data) {data.getGradientData() = su2double::GradientData();} | ||
|
||
FORCEINLINE void StartRecording() {AD::globalTape.setActive();} | ||
FORCEINLINE void StartRecording() {AD::getGlobalTape().setActive();} | ||
|
||
FORCEINLINE void StopRecording() {AD::globalTape.setPassive();} | ||
FORCEINLINE void StopRecording() {AD::getGlobalTape().setPassive();} | ||
|
||
FORCEINLINE bool TapeActive() { return AD::globalTape.isActive(); } | ||
FORCEINLINE bool TapeActive() { return AD::getGlobalTape().isActive(); } | ||
|
||
FORCEINLINE void PrintStatistics() {AD::globalTape.printStatistics();} | ||
FORCEINLINE void PrintStatistics() {AD::getGlobalTape().printStatistics();} | ||
|
||
FORCEINLINE void ClearAdjoints() {AD::globalTape.clearAdjoints(); } | ||
FORCEINLINE void ClearAdjoints() {AD::getGlobalTape().clearAdjoints(); } | ||
|
||
FORCEINLINE void ComputeAdjoint() {AD::globalTape.evaluate(); adjointVectorPosition = 0;} | ||
FORCEINLINE void ComputeAdjoint() { | ||
#if defined(HAVE_OPDI) | ||
opdi::logic->prepareEvaluate(); | ||
#endif | ||
AD::getGlobalTape().evaluate(); | ||
adjointVectorPosition = 0; | ||
} | ||
|
||
FORCEINLINE void ComputeAdjoint(unsigned short enter, unsigned short leave) { | ||
AD::globalTape.evaluate(TapePositions[enter], TapePositions[leave]); | ||
#if defined(HAVE_OPDI) | ||
opdi::logic->recoverState(TapePositions[enter].second); | ||
opdi::logic->prepareEvaluate(); | ||
AD::getGlobalTape().evaluate(TapePositions[enter].first, TapePositions[leave].first); | ||
#else | ||
AD::getGlobalTape().evaluate(TapePositions[enter], TapePositions[leave]); | ||
#endif | ||
Comment on lines
+342
to
+348
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The AD workflow is extended by OpDiLib calls. |
||
if (leave == 0) | ||
adjointVectorPosition = 0; | ||
} | ||
|
||
FORCEINLINE void Reset() { | ||
globalTape.reset(); | ||
AD::getGlobalTape().reset(); | ||
#if defined(HAVE_OPDI) | ||
opdi::logic->reset(); | ||
#endif | ||
if (inputValues.size() != 0) { | ||
adjointVectorPosition = 0; | ||
inputValues.clear(); | ||
} | ||
if (TapePositions.size() != 0) { | ||
#if defined(HAVE_OPDI) | ||
for (TapePosition& pos : TapePositions) { | ||
opdi::logic->freeState(pos.second); | ||
} | ||
#endif | ||
TapePositions.clear(); | ||
} | ||
} | ||
|
@@ -343,11 +374,11 @@ namespace AD{ | |
} | ||
|
||
FORCEINLINE void SetDerivative(int index, const double val) { | ||
AD::globalTape.setGradient(index, val); | ||
AD::getGlobalTape().setGradient(index, val); | ||
} | ||
|
||
FORCEINLINE double GetDerivative(int index) { | ||
return AD::globalTape.getGradient(index); | ||
return AD::getGlobalTape().getGradient(index); | ||
} | ||
|
||
/*--- Base case for parameter pack expansion. ---*/ | ||
|
@@ -361,6 +392,11 @@ namespace AD{ | |
SetPreaccIn(moreData...); | ||
} | ||
|
||
template<class T, class... Ts, su2enable_if<std::is_same<T,su2double>::value> = 0> | ||
FORCEINLINE void SetPreaccIn(T&& data, Ts&&... moreData) { | ||
static_assert(!std::is_same<T,su2double>::value, "rvalues cannot be registered"); | ||
} | ||
|
||
template<class T> | ||
FORCEINLINE void SetPreaccIn(const T& data, const int size) { | ||
if (PreaccActive) { | ||
|
@@ -384,20 +420,8 @@ namespace AD{ | |
} | ||
} | ||
|
||
template<class T> | ||
FORCEINLINE void SetPreaccIn(const T& data, const int size_x, const int size_y, const int size_z) { | ||
if (!PreaccActive) return; | ||
for (int i = 0; i < size_x; i++) { | ||
for (int j = 0; j < size_y; j++) { | ||
for (int k = 0; k < size_z; k++) { | ||
if (data[i][j][k].isActive()) PreaccHelper.addInput(data[i][j][k]); | ||
} | ||
} | ||
} | ||
} | ||
|
||
FORCEINLINE void StartPreacc() { | ||
if (globalTape.isActive() && PreaccEnabled) { | ||
if (AD::getGlobalTape().isActive() && PreaccEnabled) { | ||
PreaccHelper.start(); | ||
PreaccActive = true; | ||
} | ||
|
@@ -438,7 +462,11 @@ namespace AD{ | |
} | ||
|
||
FORCEINLINE void Push_TapePosition() { | ||
TapePositions.push_back(AD::globalTape.getPosition()); | ||
#if defined(HAVE_OPDI) | ||
TapePositions.push_back({AD::getGlobalTape().getPosition(), opdi::logic->exportState()}); | ||
#else | ||
TapePositions.push_back(AD::getGlobalTape().getPosition()); | ||
#endif | ||
} | ||
|
||
FORCEINLINE void EndPreacc(){ | ||
|
@@ -478,15 +506,15 @@ namespace AD{ | |
} | ||
|
||
FORCEINLINE void SetExtFuncOut(su2double& data) { | ||
if (globalTape.isActive()) { | ||
if (AD::getGlobalTape().isActive()) { | ||
FuncHelper->addOutput(data); | ||
} | ||
} | ||
|
||
template<class T> | ||
FORCEINLINE void SetExtFuncOut(T&& data, const int size) { | ||
for (int i = 0; i < size; i++) { | ||
if (globalTape.isActive()) { | ||
if (AD::getGlobalTape().isActive()) { | ||
FuncHelper->addOutput(data[i]); | ||
} | ||
} | ||
|
@@ -496,7 +524,7 @@ namespace AD{ | |
FORCEINLINE void SetExtFuncOut(T&& data, const int size_x, const int size_y) { | ||
for (int i = 0; i < size_x; i++) { | ||
for (int j = 0; j < size_y; j++) { | ||
if (globalTape.isActive()) { | ||
if (AD::getGlobalTape().isActive()) { | ||
FuncHelper->addOutput(data[i][j]); | ||
} | ||
} | ||
|
@@ -511,7 +539,7 @@ namespace AD{ | |
FORCEINLINE void EndExtFunc() { delete FuncHelper; } | ||
|
||
FORCEINLINE bool BeginPassive() { | ||
if(AD::globalTape.isActive()) { | ||
if(AD::getGlobalTape().isActive()) { | ||
StopRecording(); | ||
return true; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Positions in the AD recording are now identified by the CoDiPack tape position together with the corresponding OpDiLib state.