Commit 8a6abd72 authored by Pierre NARVOR's avatar Pierre NARVOR
Browse files

[cuda/functors] Modified warning fix in FunctorCompound (not pretty...)

parent d4d8f116
......@@ -80,7 +80,9 @@ struct FunctorCompound
// below is to suppress the warning but has no effect on the code. See
// here for more info :
// https://stackoverflow.com/questions/64523302/cuda-missing-return-statement-at-end-of-non-void-function-in-constexpr-if-fun
return std::get<Level>(functors_)(input);
// CAUTION : THIS CODE IMPLIES THAT ALL FUNCTORS OUTPUT MUST BE DEFAULT
// CONSTRUCTIBLE. MAYBE KEEPING THE WARNING IS BETTER.
return typename functor_get<Level>::OutputT();
}
public:
......
......@@ -7,9 +7,33 @@ using namespace rtac::cuda;
#include "functors_test.h"
template <typename T>
void print_type(std::ostream& os = std::cout) { os << "'print_not_defined'"; }
template<> void print_type<float> (std::ostream& os) { os << "'float'"; }
template<> void print_type<float4>(std::ostream& os) { os << "'float4'"; }
template <class FunctorT>
void print_functor_type(std::ostream& os = std::cout)
{
os << "(InputT : "; print_type<typename FunctorT::InputT> (os);
os << ", OutputT : "; print_type<typename FunctorT::OutputT>(os);
os << ")" << endl;
}
int main()
{
int N = 10;
print_functor_type<Vectorize4>();
print_functor_type<Norm4>();
print_functor_type<MultiType>();
//MultiType fm(Norm4(), Vectorize4()); // why this not working ?
//auto fm = MultiType(Norm4(), Vectorize4());
MultiType fm(std::make_tuple(Norm4(), Vectorize4()));
print_functor_type<decltype(fm)>();
//cout << std::get<0>(fm.functors_) << endl;
cout << fm(1.0f) << endl;
HostVector<float> input(N);
for(int n = 0; n < N; n++) {
......@@ -18,7 +42,7 @@ int main()
//auto output = scaling(input, functor::Scaling<float>({2.0f}));
auto f = Saxpy(functors::Offset<float>({3.0f}), functors::Scaling<float>({2.0f}));
Saxpy f = Saxpy(functors::Offset<float>({3.0f}), functors::Scaling<float>({2.0f}));
cout << f(1.0f) << endl;
auto output = saxpy(input, Saxpy(functors::Offset<float>({3.0f}),
......
......@@ -7,8 +7,39 @@
namespace rtac { namespace cuda {
struct Vectorize4 {
using InputT = float;
using OutputT = float4;
float x;
RTAC_HOSTDEVICE float4 operator()(float input) const {
return float4({input, input, input, input});
}
};
struct Norm4 {
using InputT = float4;
using OutputT = float;
float x;
RTAC_HOSTDEVICE float operator()(const float4& input) const {
// return length(input); // WHY U NOT WORKING ???
return sqrt( input.x*input.x
+ input.y*input.y
+ input.z*input.z
+ input.w*input.w);
}
};
using MultiType = functors::FunctorCompound<Norm4, Vectorize4>;
using Saxpy = functors::FunctorCompound<functors::Offset<float>, functors::Scaling<float>>;
DeviceVector<float> scaling(const DeviceVector<float>& input,
const functors::Scaling<float>& func);
DeviceVector<float> saxpy(const DeviceVector<float>& input, const Saxpy& func);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment