-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #118 from pr4deepr/master
gradient_z support Congrats @pr4deepr
- Loading branch information
Showing
6 changed files
with
313 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#ifndef __TIER1_CLEGRADIENTZKERNEL_HPP | ||
#define __TIER1_CLEGRADIENTZKERNEL_HPP | ||
|
||
#include "cleOperation.hpp" | ||
|
||
namespace cle | ||
{ | ||
|
||
class GradientZKernel : public Operation | ||
{ | ||
public: | ||
explicit GradientZKernel(const ProcessorPointer & device); | ||
|
||
auto | ||
SetInput(const Image & object) -> void; | ||
|
||
auto | ||
SetOutput(const Image & object) -> void; | ||
}; | ||
|
||
inline auto | ||
GradientZKernel_Call(const std::shared_ptr<cle::Processor> & device, const Image & src, const Image & dst) -> void | ||
{ | ||
GradientZKernel kernel(device); | ||
kernel.SetInput(src); | ||
kernel.SetOutput(dst); | ||
kernel.Execute(); | ||
} | ||
|
||
} // namespace cle | ||
|
||
#endif // __TIER1_CLEGradientZKERNEL_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#include "cleGradientZKernel.hpp" | ||
|
||
namespace cle | ||
{ | ||
|
||
GradientZKernel::GradientZKernel(const ProcessorPointer & device) | ||
: Operation(device, 2) | ||
{ | ||
std::string cl_header_ = { | ||
#include "cle_gradient_z.h" | ||
}; | ||
this->SetSource("gradient_z", cl_header_); | ||
} | ||
|
||
auto | ||
GradientZKernel::SetInput(const Image & object) -> void | ||
{ | ||
this->AddParameter("src", object); | ||
} | ||
|
||
auto | ||
GradientZKernel::SetOutput(const Image & object) -> void | ||
{ | ||
this->AddParameter("dst", object); | ||
} | ||
|
||
} // namespace cle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
|
||
|
||
#include "clesperanto.hpp" | ||
|
||
#include <random> | ||
|
||
template <class type> | ||
auto | ||
run_test(const std::array<size_t, 3> & shape, const cle::MemoryType & mem_type) -> bool | ||
{ | ||
std::vector<type> input = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; | ||
std::vector<type> valid = { 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0 }; | ||
|
||
cle::Clesperanto cle; | ||
cle.GetDevice()->WaitForKernelToFinish(); | ||
auto gpu_input = cle.Push<type>(input, shape, mem_type); | ||
auto gpu_output = cle.Create<type>(shape, mem_type); | ||
cle.GradientZ(gpu_input, gpu_output); | ||
auto output = cle.Pull<type>(gpu_output); | ||
return std::equal(output.begin(), output.end(), valid.begin()); | ||
} | ||
|
||
auto | ||
main(int argc, char ** argv) -> int | ||
{ | ||
if (!run_test<float>({ 3, 3, 3 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
/* | ||
if (!run_test<signed int>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned int>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed short>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned short>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed char>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned char>({ 10, 1, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<float>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed int>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned int>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed short>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned short>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed char>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned char>({ 10, 7, 1 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<float>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed int>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned int>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed short>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned short>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<signed char>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
if (!run_test<unsigned char>({ 10, 7, 5 }, cle::BUFFER)) | ||
{ | ||
return EXIT_FAILURE; | ||
} | ||
// if (!run_test<float>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed int>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned int>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed short>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned short>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed char>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned char>({ 10, 1, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<float>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed int>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned int>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed short>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned short>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed char>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned char>({ 10, 7, 1 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<float>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed int>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned int>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed short>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned short>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<signed char>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
// if (!run_test<unsigned char>({ 10, 7, 5 }, cle::IMAGE)) | ||
// { | ||
// return EXIT_FAILURE; | ||
// } | ||
*/ | ||
return EXIT_SUCCESS; | ||
} |