From 76f355367920a01987aa9664389d8002b8344ba5 Mon Sep 17 00:00:00 2001
From: Yurii Shyrma
Date: Tue, 12 May 2020 07:47:09 +0300
Subject: [PATCH] Shyrma merge max ind (#443)

* - provide correct possible output types in mergeMaxIndex op
Signed-off-by: Yurii

* - cleaning up the unneeded backprop arg in reverse_bp op
Signed-off-by: Yurii

* - improve clipByNorm both ff and bp
Signed-off-by: Yurii

* - implementation and testing clipByAvgNorm_bp op
Signed-off-by: Yurii

* - pass biases in any way in dnnl lstm op, they are zeros when user doesn't provide them to us
Signed-off-by: Yurii

* - start working on mkldnn concat op
Signed-off-by: Yurii

* - further work on mkldnn concat
Signed-off-by: Yurii

* missing declaration fix
Signed-off-by: raver119@gmail.com

* - polishing mkl ops
Signed-off-by: Yurii

* - testing and fixing bugs in mkl concat op
Signed-off-by: Yurii

* - fix linkage error for windows cuda build
Signed-off-by: Yurii

* - further conflicts resolving with master
Signed-off-by: Yurii

* - fix format tags in mkldnn matmul op
Signed-off-by: Yurii

* - provide additional type cast in clip.cu
Signed-off-by: Yurii

* - finally bug in mkldnn tanh_bp was caught
Co-authored-by: raver119@gmail.com
---
 libnd4j/include/array/NDArray.h | 8 +-
 libnd4j/include/array/NDArray.hXX | 10 +-
 .../transforms/clip_by_averaged_norm.cpp | 51 +-
 .../generic/transforms/clip_by_norm.cpp | 10 +-
 .../declarable/generic/transforms/concat.cpp | 5 +-
 .../generic/transforms/merge_max_idx.cpp | 5 +-
 .../declarable/generic/transforms/reverse.cpp | 4 +-
 .../ops/declarable/headers/transforms.h | 1 +
 .../ops/declarable/helpers/cpu/clip.cpp | 276 ++++------
 .../ops/declarable/helpers/cpu/merge.cpp | 35 +-
 .../ops/declarable/helpers/cpu/reverse.cpp | 9 +-
 .../ops/declarable/helpers/cuda/clip.cu | 334 ++++++++++++
 .../ops/declarable/helpers/cuda/reverse.cu | 10 +-
 .../ops/declarable/helpers/cuda/transforms.cu | 509 ------------------
 .../include/ops/declarable/helpers/reverse.h | 4 +-
 .../ops/declarable/helpers/transforms.h | 8 +-
 .../ops/declarable/impl/DeclarableOp.cpp | 2 +-
 .../declarable/platform/mkldnn/batchnorm.cpp | 46 +-
 .../ops/declarable/platform/mkldnn/concat.cpp | 186 +++++++
 .../ops/declarable/platform/mkldnn/conv2d.cpp | 111 ++--
 .../ops/declarable/platform/mkldnn/conv3d.cpp | 114 ++--
 .../declarable/platform/mkldnn/deconv2d.cpp | 97 ++--
 .../platform/mkldnn/deconv2d_tf.cpp | 27 +-
 .../declarable/platform/mkldnn/deconv3d.cpp | 100 ++--
 .../platform/mkldnn/depthwiseConv2d.cpp | 51 +-
 .../declarable/platform/mkldnn/lstmLayer.cpp | 123 ++---
 .../ops/declarable/platform/mkldnn/matmul.cpp | 60 +--
 .../platform/mkldnn/mkldnnUtils.cpp | 154 +++---
 .../declarable/platform/mkldnn/mkldnnUtils.h | 20 +-
 .../declarable/platform/mkldnn/softmax.cpp | 91 ++--
 .../ops/declarable/platform/mkldnn/tanh.cpp | 67 +--
 .../declarable/platform/mkldnn/xw_plus_b.cpp | 146 ++---
 libnd4j/include/ops/impl/specials_single.hpp | 23 +-
 .../layers_tests/DeclarableOpsTests16.cpp | 431 ++++++++++++++-
 .../layers_tests/DeclarableOpsTests3.cpp | 150 ++----
 .../layers_tests/DeclarableOpsTests5.cpp | 254 ++++-----
 .../layers_tests/DeclarableOpsTests6.cpp | 210 ++++----
 .../layers_tests/DeclarableOpsTests8.cpp | 504 ++++++-----
 .../layers_tests/DeclarableOpsTests9.cpp | 281 +++-------
 39 files changed, 2130 insertions(+), 2397 deletions(-)
 create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/clip.cu
 create mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp

diff --git a/libnd4j/include/array/NDArray.h
b/libnd4j/include/array/NDArray.h index 7936f6688..ae4df227d 100644 --- a/libnd4j/include/array/NDArray.h +++ b/libnd4j/include/array/NDArray.h @@ -981,12 +981,12 @@ namespace sd { * these methods suited for FlatBuffers use */ template - std::vector getBufferAsVector(); + std::vector getBufferAsVector() const; std::vector getShapeAsVector() const; std::vector getShapeAsVectorInt() const; - std::vector getShapeInfoAsVector(); - std::vector getShapeInfoAsFlatVector(); - std::vector getShapeAsFlatVector(); + std::vector getShapeInfoAsVector() const; + std::vector getShapeInfoAsFlatVector() const; + std::vector getShapeAsFlatVector() const; /** * set new order and shape in case of suitable array length (in-place operation) diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 42f5f47f3..786333eec 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -982,16 +982,16 @@ std::string NDArray::asString(Nd4jLong limit) { //////////////////////////////////////////////////////////////////////// template -std::vector NDArray::getBufferAsVector() { +std::vector NDArray::getBufferAsVector() const { std::vector vector(lengthOf()); for (Nd4jLong e = 0; e < lengthOf(); e++) vector[e] = this->e(e); return vector; } -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsFlatVector() { +std::vector NDArray::getShapeAsFlatVector() const { std::vector vector(this->rankOf()); for (int e = 0; e < this->rankOf(); e++) vector[e] = static_cast(this->sizeAt(e)); @@ -1019,7 +1019,7 @@ std::vector NDArray::getShapeAsVectorInt() const { } //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsFlatVector() { +std::vector NDArray::getShapeInfoAsFlatVector() const { int magicNumber = shape::shapeInfoLength(this->rankOf()); std::vector vector(magicNumber); @@ -1030,7 +1030,7 @@ std::vector NDArray::getShapeInfoAsFlatVector() { } //////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsVector() { +std::vector NDArray::getShapeInfoAsVector() const { int magicNumber = shape::shapeInfoLength(this->rankOf()); std::vector vector(magicNumber); for (int e = 0; e < magicNumber; e++) diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp index 958a90410..a7340bf21 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_averaged_norm.cpp @@ -15,7 +15,8 @@ ******************************************************************************/ // -// @author raver119@gmail.com +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -27,24 +28,58 @@ namespace sd { namespace ops { +////////////////////////////////////////////////////////////////////////// CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); const bool isInplace = block.isInplace(); - auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext()); + auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext()); - 
helpers::clipByAveraged(block.launchContext(), *input, *output, *block.getIArguments(), ts, isInplace); + helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, true); return Status::OK(); } - DECLARE_TYPES(clipbyavgnorm) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } +DECLARE_TYPES(clipbyavgnorm) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) { + + auto input = INPUT_VARIABLE(0); + auto gradO = INPUT_VARIABLE(1); + + auto gradI = OUTPUT_VARIABLE(0); + + const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); + + helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(clipbyavgnorm_bp) { + + Nd4jLong *newShape = nullptr; + COPY_SHAPE(inputShape->at(1), newShape); + + return SHAPELIST(CONSTANT(newShape)); +} + + +DECLARE_TYPES(clipbyavgnorm_bp) { + getOpDescriptor() + ->setAllowedInputTypes(0, DataType::ANY) + ->setAllowedInputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(0, {ALL_FLOATS}); +} + } } diff --git a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp index 43b23ba18..75145f7cc 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/clip_by_norm.cpp @@ -31,10 +31,10 @@ namespace ops { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext()); + const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), block.launchContext()); const bool isInplace = block.isInplace(); - helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace); + helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, false); return Status::OK(); } @@ -45,15 +45,15 @@ namespace ops { auto gradO = INPUT_VARIABLE(1); auto gradI = OUTPUT_VARIABLE(0); - const auto clipNorm = NDArrayFactory::create(T_ARG(0)); + const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext()); - helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm); + helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, false); return Status::OK(); } DECLARE_SHAPE_FN(clipbynorm_bp) { - auto inShapeInfo = inputShape->at(0); + auto inShapeInfo = inputShape->at(1); Nd4jLong *newShape = nullptr; COPY_SHAPE(inShapeInfo, newShape); diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index fb1fd2e87..1cf750e00 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -23,8 +23,8 @@ #include #include -namespace sd { -namespace ops { +namespace sd { +namespace ops { ////////////////////////////////////////////////////////////////////////// @@ -85,6 +85,7 @@ 
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) { // ******** input validation ******** // REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !"); + REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as inputs arrays !"); REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); for(int i = 1; i < numOfNonEmptyArrs; ++i) diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index 1ffe42f4b..3c76450aa 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) { auto output = OUTPUT_VARIABLE(0); std::vector inArrs(block.width()); - + for(int i = 0; i < block.width(); ++i) inArrs[i] = INPUT_VARIABLE(i); @@ -46,7 +46,8 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex); DECLARE_TYPES(mergemaxindex) { getOpDescriptor() - ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); + ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}) + ->setAllowedOutputTypes({ALL_INDICES}); } } DECLARE_SHAPE_FN(mergemaxindex) { diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 401b68d00..e8f659c5d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -52,7 +52,7 @@ namespace ops { else { // check the consistency of input dimensions to reverse along shape::checkDimensions(input->rankOf(), axis); - helpers::reverse(block.launchContext(), input, output, &axis, false); + helpers::reverse(block.launchContext(), input, output, &axis); } return Status::OK(); @@ -85,7 +85,7 @@ namespace ops { // check the consistency of input dimensions to reverse along shape::checkDimensions(input->rankOf(), axis); // we just reverse back original array - helpers::reverse(block.launchContext(), eps, output, &axis, false); + helpers::reverse(block.launchContext(), eps, output, &axis); } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index 29efc4a73..3fe2f1223 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -36,6 +36,7 @@ namespace sd { #if NOT_EXCLUDED(OP_clipbyavgnorm) DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0); + DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0); #endif #if NOT_EXCLUDED(OP_cumsum) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp index d4240d780..2c2d9a111 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/clip.cpp @@ -15,83 +15,134 @@ ******************************************************************************/ // -// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 +// @author Yurii Shyrma (iuriish@yahoo.com) +// @author sgazeos@gmail.com +// @author raver119@gmail.com // #include -#include +#include namespace sd { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// -template -static void 
clipByNorm_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage) { - const int rank = input.rankOf(); - const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); + NDArray* z = nullptr; - const T normActual = norm2.e(0); - const T normClip = clipNorm.e(0); + if(isInplace) { + z = &input; + } + else { + output.assign(input); + z = &output; + } - if (isInplace) { + if(dimensions.empty()) { - if(norm2.lengthOf() == 1) { + const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {}); - if(normActual > normClip) - input *= (normClip / normActual); - } - else { - - auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - const T iNormActual = norm2.e(i); - if (iNormActual > normClip) - *listOfInSubArrs.at(i) *= normClip / iNormActual; - } - }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); - } + if(actualNorm.e(0) > clipNorm.e(0)) + *z *= clipNorm / actualNorm; } else { - if(norm2.lengthOf() == 1) { + auto listOfSubArrs = z->allTensorsAlongDimension(dimensions); - if(normActual > normClip) - output.assign(input * (normClip / normActual)); - else - output.assign(input); - } - else { - - auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions); - auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto inputSubArr = listOfInSubArrs.at(i); - auto outputSubArr = listOfOutSubArrs.at(i); - outputSubArr->assign(inputSubArr); - - const T iNormActual = norm2.e(i); - - if (iNormActual > clipNorm.e(0)) - *outputSubArr *= clipNorm / iNormActual; - } - }; - samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size()); - } + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}); + if(actualNorm.e(0) > clipNorm.e(0)) + *listOfSubArrs.at(i) *= clipNorm / actualNorm; + } + }; + samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size()); } } + ////////////////////////////////////////////////////////////////////////// -void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); +template +static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage) { + + const int rank = input.rankOf(); + + auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); + auto sums = input.reduceAlongDimension(reduce::Sum, dimensions); + + if(norm2.lengthOf() == 1) { + + const T norm = useAverage ? 
norm2.e(0) / input.lengthOf() : norm2.e(0); + + auto clipVal = clipNorm.e(0); + + if(norm > clipVal) { + + const T sum = sums.e(0); // reduce to scalar + const T factor1 = clipVal / norm; + const T factor2 = static_cast(1.f) / (norm * norm); // 1 / (norm*norm*norm) + + auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) { + return factor1 * y * (static_cast(1.f) - factor2 * x * sum); + }; + + const_cast(input).applyPairwiseLambda(const_cast(gradO), lambda, gradI); + } + else + gradI.assign(gradO); + } + else { + + auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); + auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); + auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); + + auto clipVal = clipNorm.e(0); + + auto func = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; i++) { + + auto gradOSubArr = gradOSubArrs.at(i); + auto gradISubArr = gradISubArrs.at(i); + + const T norm = useAverage ? norm2.e(i) / gradISubArr->lengthOf() : norm2.e(i); + + if (norm > clipVal) { + + auto inputSubArr = inputSubArrs.at(i); + + const T sum = sums.e(i); // reduce to scalar + const T factor1 = clipVal / norm; + const T factor2 = static_cast(1.f) / (norm * norm); // 1 / (norm*norm*norm) + + auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) { + return factor1 * y * (static_cast(1.f) - factor2 * x * sum); + }; + + inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); + } + else + gradISubArr->assign(gradOSubArr); + } + }; + samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); + } } +BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES); + +////////////////////////////////////////////////////////////////////////// +void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage) { + + const NDArray& castedInput = gradI.dataType() == input.dataType() ? 
input : input.cast(gradI.dataType()); + + BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES); +} + + template @@ -132,125 +183,6 @@ void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, co BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); -////////////////////////////////////////////////////////////////////////// -template -static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - - const int rank = input.rankOf(); - - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - - if(norm2.lengthOf() == 1) { - - const T N = norm2.e(0); - - auto cn = clipNorm.e(0); - - if(N > cn) { - - const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e(0); // reduce to scalar - const T factor1 = static_cast(1.f) / N; - const T factor3 = factor1 / (N * N); // 1 / (N*N*N) - - auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { - return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); - }; - - (const_cast(input)).applyPairwiseLambda(const_cast(gradO), lambda, gradI); - } - else - gradI.assign(gradO); - } - else { - - auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions}); - auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions}); - auto inputSubArrs = input.allTensorsAlongDimension({dimensions}); - - auto cn = clipNorm.e(0); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - T N = norm2.e(i); - - auto gradOSubArr = gradOSubArrs.at(i); - auto gradISubArr = gradISubArrs.at(i); - - if (N > cn) { - auto inputSubArr = inputSubArrs.at(i); - const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e(0); // reduce to scalar - const T factor1 = static_cast(1.f) / N; - const T factor3 = factor1 / (N * N); // 1 / (N*N*N) - - auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) { - return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd); - }; - - inputSubArr->applyPairwiseLambda(*gradOSubArr, lambda, *gradISubArr); - } else - gradISubArr->assign(gradOSubArr); - } - }; - samediff::Threads::parallel_tad(func, 0, gradISubArrs.size()); - } -} - - void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm), FLOAT_TYPES); - - -////////////////////////////////////////////////////////////////////////// -template -static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - - auto cn = clipNorm.e(0); - if (dimensions.size() == 0) { - // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / input.lengthOf(); - if (n2 <= cn) { - if (!isInplace) - output.assign(input); - } - else { - const T factor = cn / n2; - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input.applyLambda(lambda, output); - } - } - else { - // along dimension - auto norm2 = 
input.reduceAlongDimension(reduce::Norm2, dimensions, false); - if (!isInplace) - output.assign(input); - auto tads = output.allTensorsAlongDimension(dimensions); - // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads.size(); e++) { - T n2 = norm2.e(e) / tads.at(e)->lengthOf(); - const T factor = cn / n2; - if (n2 > cn) { - auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads.at(e)->applyLambda(lambda, output); - } - } - } -} - - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - -/* - if (d1 > params[1]) - return params[1]; - else if (d1 < params[0]) - return params[0]; - else return d1; -*/ template static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp index 7874d6d67..d748aa6b0 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/merge.cpp @@ -29,7 +29,7 @@ namespace helpers { ////////////////////////////////////////////////////////////////////////// -template +template static void mergeMaxIndex_(const std::vector& inArrs, NDArray& output) { const Nd4jLong numArgs = inArrs.size(); @@ -37,17 +37,18 @@ static void mergeMaxIndex_(const std::vector& inArrs, NDArray& o auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { - T max = -DataTypeUtils::max(); - Nd4jLong idx = 0; + X max = -DataTypeUtils::max(); + Z idx = static_cast(0); for (Nd4jLong i = 0; i < numArgs; i++) { - T v = inArrs[i]->e(e); + X v = inArrs[i]->t(e); if (v > max) { max = v; - idx = i; + idx = static_cast(i); } } - output.p(e, idx); + // FIXME, use .r(e) + output.t(e) = static_cast(idx); } }; @@ -55,14 +56,14 @@ static void mergeMaxIndex_(const std::vector& inArrs, NDArray& o } void mergeMaxIndex(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output) { - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); } ////////////////////////////////////////////////////////////////////////// template static void mergeMax_(const std::vector& inArrs, NDArray& output) { - + const Nd4jLong numArgs = inArrs.size(); auto x = inArrs[0]; @@ -89,15 +90,15 @@ void mergeMax(sd::LaunchContext * context, const std::vector& in ////////////////////////////////////////////////////////////////////////// template static void mergeMaxBp_(const std::vector& inArrs, std::vector& outArrs) { - + // outArrs.size() == inArrs.size() - 1 const Nd4jLong numArgs = outArrs.size(); // last array is gradient const auto gradient = inArrs[numArgs]->bufferAsT(); auto length = inArrs[numArgs]->lengthOf(); - + bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews()); - + if (bSameOrderAndEws1) { auto gradOrdering = inArrs[numArgs]->ordering(); @@ -108,8 +109,8 @@ static void mergeMaxBp_(const std::vector& inArrs, std::vectorews()); } } - - + + if(bSameOrderAndEws1){ auto func = PRAGMA_THREADS_FOR{ 
for (auto e = start; e < stop; e++) { @@ -130,7 +131,7 @@ static void mergeMaxBp_(const std::vector& inArrs, std::vectorshapeInfo(); std::vector vbSameShaepeAndStrides(numArgs); for (int i = 0; i < numArgs; ++i) { @@ -145,12 +146,12 @@ static void mergeMaxBp_(const std::vector& inArrs, std::vector(); Nd4jLong nMaxIndex = 0; for (Nd4jLong i = 0; i < numArgs; i++) { - + const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords); const T* v = inArrs[i]->bufferAsT(); if (v[xOffset] > max) { @@ -160,7 +161,7 @@ static void mergeMaxBp_(const std::vector& inArrs, std::vectorshapeInfo(), coords); - + T* z = outArrs[nMaxIndex]->bufferAsT(); z[zOffset] = gradient[gradOffset]; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp index 95417dade..bc072682a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/reverse.cpp @@ -193,13 +193,10 @@ static void reverseSequence_(sd::LaunchContext * context, const NDArray* input, } ////////////////////////////////////////////////////////////////////////// -void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp) { +void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs) { - // we need to reverse axis only if that's new op - std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; - - auto listOut = output->allTensorsAlongDimension(dimensions); - auto listIn = input->allTensorsAlongDimension(dimensions); + auto listOut = output->allTensorsAlongDimension(*intArgs); + auto listIn = input->allTensorsAlongDimension(*intArgs); NDArray *subArrIn, *subArrOut; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/clip.cu b/libnd4j/include/ops/declarable/helpers/cuda/clip.cu new file mode 100644 index 000000000..8f1be21e4 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/clip.cu @@ -0,0 +1,334 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// @author sgazeos@gmail.com +// @author raver119@gmail.com +// + + +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void clipByNormCuda(const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int* dimensions, const int dimsLen, const bool useAverage) { + + const T clipNorm = *reinterpret_cast(vClipNorm); + const T* norm = reinterpret_cast(vNorm); + T* z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, tadLen, totalThreads; + + if (threadIdx.x == 0) { + + zLen = shape::length(zShapeInfo); + tadLen = zLen / shape::length(normShapeInfo); + totalThreads = gridDim.x * blockDim.x; + } + + __syncthreads(); + + int zCoords[MAX_RANK], normCoords[MAX_RANK]; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + + shape::index2coords(i, zShapeInfo, zCoords); + + // deduce norm coords + for (int j = 0; j < dimsLen; ++j) + normCoords[j] = zCoords[dimensions[j]]; + + const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)]; + + if(actualNorm > clipNorm) + z[shape::getOffset(zShapeInfo, zCoords)] *= clipNorm / actualNorm; + } +} + +////////////////////////////////////////////////////////////////////////// +template +__host__ static void clipByNormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, + const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo, + const int* dimensions, const int dimsLen, const bool useAverage) { + + clipByNormCuda<<>>(vClipNorm, vNorm, normShapeInfo, vz, zShapeInfo, dimensions, dimsLen, useAverage); +} + +////////////////////////////////////////////////////////////////////////// +void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector& dims, const NDArray& clipNorm, const bool isInplace, const bool useAverage) { + + NDArray* z = nullptr; + + if(isInplace) { + z = &input; + } + else { + output.assign(input); + z = &output; + } + + if(dims.empty()) { + + const NDArray actualNorm = useAverage ? 
z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {}); + + if(actualNorm.e(0) > clipNorm.e(0)) + *z *= clipNorm / actualNorm; + } + else { + + const NDArray actualNorms = z->reduceAlongDimension(reduce::Norm2, dims); + + std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(z->rankOf(), dims); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (z->lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "clipByNorm"); + + const int* dimensions = reinterpret_cast(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int))); + + NDArray::prepareSpecialUse({z}, {z, &actualNorms, &clipNorm}); + BUILD_SINGLE_SELECTOR(z->dataType(), clipByNormCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), clipNorm.specialBuffer(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage), FLOAT_TYPES); + NDArray::registerSpecialUse({z}, {z, &actualNorms, &clipNorm}); + + manager.synchronize(); + } +} + +////////////////////////////////////////////////////////////////////////// +template +__global__ static void clipByNormBpCuda(const void* vClipNorm, + const void* vx, const Nd4jLong* xShapeInfo, // input + const void* vy, const Nd4jLong* yShapeInfo, // gradO + const void* vNorm, const Nd4jLong* normShapeInfo, + const void* vSum, const Nd4jLong* sumShapeInfo, + void* vz, const Nd4jLong* zShapeInfo, // gradI + const int* dimensions, const int dimsLen, const bool useAverage) { + + const T clipNorm = *reinterpret_cast(vClipNorm); + const T* norm = reinterpret_cast(vNorm); + const T* sum = reinterpret_cast(vSum); + const T* x = reinterpret_cast(vx); + const T* y = reinterpret_cast(vy); + T* z = reinterpret_cast(vz); + + __shared__ Nd4jLong zLen, tadLen, totalThreads; + __shared__ bool sameOffsets; + + if (threadIdx.x == 0) { + + zLen = shape::length(zShapeInfo); + tadLen = zLen / shape::length(normShapeInfo); + totalThreads = gridDim.x * blockDim.x; + + sameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo); + } + + __syncthreads(); + + int zCoords[MAX_RANK], normCoords[MAX_RANK]; + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { + + shape::index2coords(i, zShapeInfo, zCoords); + + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + const auto yOffset = sameOffsets ? zOffset : shape::getOffset(yShapeInfo, zCoords); + + // deduce norm coords + for (int j = 0; j < dimsLen; ++j) + normCoords[j] = zCoords[dimensions[j]]; + + const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)]; + + if(actualNorm > clipNorm) { + + const T sumVal = sum[shape::getOffset(sumShapeInfo, normCoords)]; + const auto xOffset = sameOffsets ? 
zOffset : shape::getOffset(xShapeInfo, zCoords); + + z[zOffset] = (clipNorm / actualNorm) * y[yOffset] * (static_cast(1.f) - (x[xOffset] * sumVal) / (actualNorm * actualNorm)); + } + else + z[zOffset] = y[yOffset]; + } +} + +////////////////////////////////////////////////////////////////////////// +template +void clipByNormBp_(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dims, const NDArray& clipNorm, const bool useAverage) { + + const int rank = input.rankOf(); + + auto actualNorms = input.reduceAlongDimension(reduce::Norm2, dims); + + if(actualNorms.lengthOf() == 1) { + + const T norm = useAverage ? actualNorms.e(0) / static_cast(input.lengthOf()) : actualNorms.e(0); + + auto clipVal = clipNorm.e(0); + + if(norm > clipVal) { + + const T sum = input.reduceNumber(reduce::Sum).e(0); // reduce to scalar + const T factor1 = clipVal / norm; + const T factor2 = static_cast(1.f) / (norm * norm); // 1 / (norm*norm*norm) + + auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) { + return factor1 * y * (static_cast(1.f) - factor2 * x * sum); + }; + + const_cast(input).applyPairwiseLambda(const_cast(gradO), lambda, gradI); + } + else + gradI.assign(gradO); + } + else { + + const NDArray actualNorms = input.reduceAlongDimension(reduce::Norm2, dims); + const NDArray sums = input.reduceAlongDimension(reduce::Sum, dims); + + std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(gradI.rankOf(), dims); + + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + PointersManager manager(context, "clipByNormBp"); + + const int* dimensions = reinterpret_cast(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int))); + + NDArray::prepareSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO}); + clipByNormBpCuda<<getCudaStream()>>>(clipNorm.specialBuffer(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), sums.specialBuffer(), sums.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage); + NDArray::registerSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO}); + + manager.synchronize(); + } +} +BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES); + +////////////////////////////////////////////////////////////////////////// +void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage) { + + const NDArray& castedInput = gradI.dataType() == input.dataType() ? 
input : input.cast(gradI.dataType()); + BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (context, castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES); +} + + + + + + + template + void clipByGlobalNorm_(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { + NDArray globalNorm = NDArrayFactory::create(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) + + for (auto i = 0; i < inputs.size(); i++) { + auto input = inputs[i]; + auto l2norm = input->reduceNumber(reduce::Norm2); + globalNorm += l2norm * l2norm; + } + + globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm); + outputs[inputs.size()]->p(0, globalNorm); + globalNorm.syncToHost(); + const T factor = static_cast(clipNorm) / globalNorm.e(0); + + for (size_t e = 0; e < inputs.size(); e++) { + // all-reduce + auto input = inputs[e]; + auto output = outputs[e]; + + if (globalNorm.e(0) <= clipNorm) { + output->assign(input); + } + else { + + auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; + input->applyLambda(lambda, *output); + } + } + } + + void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { + BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); + } + + BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); + + + template + static void __global__ clipByValueKernel(void* input, const Nd4jLong* inputShape, void* output, const Nd4jLong* outputShape, double leftBound, double rightBound) { + __shared__ T* outputBuf; + __shared__ T* inputBuf; + __shared__ Nd4jLong length; + __shared__ bool linearBuffers; + if (threadIdx.x == 0) { + outputBuf = reinterpret_cast(output); + inputBuf = reinterpret_cast(input); + length = shape::length(inputShape); + linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1; + } + __syncthreads(); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + for (Nd4jLong e = tid; e < length; e += step) { + if (linearBuffers) { + if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound; + else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound; + else outputBuf[e] = inputBuf[e]; + } + else { + auto inputOffset = shape::getIndexOffset(e, inputShape); + auto outputOffset = shape::getIndexOffset(e, outputShape); + if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound; + else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound; + else outputBuf[outputOffset] = inputBuf[outputOffset]; + } + } + } + + template + static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { + auto stream = context->getCudaStream(); + if (!input.isActualOnDeviceSide()) + input.syncToDevice(); + NDArray::prepareSpecialUse({&output}, {&input}); + clipByValueKernel<<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound); + 
NDArray::registerSpecialUse({&output}, {&input}); + } + + void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { + BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES); + } + + BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES); + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index b6bbeea4c..2ed45356e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -210,14 +210,10 @@ namespace helpers { } ////////////////////////////////////////////////////////////////////////// - void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp) { - // we need to reverse axis only if that's new op - std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; - std::vector axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions); - + void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs) { + auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), *intArgs); + auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), *intArgs); NDArray::prepareSpecialUse({output}, {input}); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index f016491a6..f14b12e35 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -300,269 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray manager.synchronize(); } -////////////////////////////////////////////////////////////////////////// -// x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) { - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - if(tid >= shape::length(zShapeInfo)) - return; - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - auto reducBuff = reinterpret_cast(vreducBuff); - uint* count = reinterpret_cast(vreducBuff) + 16384; - - __shared__ Z* shMem; - __shared__ Nd4jLong len; - __shared__ bool amIinLastBlock; - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - shMem = reinterpret_cast(shmem); - - len = shape::length(zShapeInfo); // xLen = yLen = zLen - } - __syncthreads(); - - // fill shared memory with array elements - const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)]; - const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)]; - - shMem[2*threadIdx.x] = static_cast(xVal * xVal); // for norm - shMem[2*threadIdx.x + 1] = static_cast(xVal * yVal); // for input * gradO - - __syncthreads(); - - // accumulate sum per block - 
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && tid + activeThreads < len) { - - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); - } - - // store accumulated sums in reduction buffer (reducBuff) - if (threadIdx.x == 0) { - - reducBuff[2*blockIdx.x] = shMem[0]; - reducBuff[2*blockIdx.x + 1] = shMem[1]; - - __threadfence(); - - amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1); - } - __syncthreads(); - - // shared memory of last block is used for final summation of values stored in reduction buffer - if (amIinLastBlock) { - - for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) { - - shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x]; - shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1]; - } - __syncthreads(); - - // accumulate sum - for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) { - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); - } - - if (threadIdx.x == 0) { - - reducBuff[0] = math::nd4j_sqrt(shMem[0]); - reducBuff[1] = shMem[1]; - count = 0; - } - } -} - -////////////////////////////////////////////////////////////////////////// -// x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) { - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - const Nd4jLong len = shape::length(zShapeInfo); // xLen = yLen = zLen - - if(tid >= len) - return; - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Z norm, sumOfProd; - - if (threadIdx.x == 0) { - - norm = reinterpret_cast(vreducBuff)[0]; - sumOfProd = reinterpret_cast(vreducBuff)[1]; - } - __syncthreads(); - - const auto yOffset = shape::getIndexOffset(tid, yShapeInfo); - const auto zOffset = shape::getIndexOffset(tid, zShapeInfo); - - if(norm > clipNormVal) { - - const auto xOffset = shape::getIndexOffset(tid, xShapeInfo); - - const Z factor1 = static_cast(1) / norm; // 1 / norm - const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) - - z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]); - } - else { - z[zOffset] = y[yOffset]; - } -} - -////////////////////////////////////////////////////////////////////////// -// x - input, y - gradO, z - gradI -template -__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ Z* shMem; - __shared__ Nd4jLong tadLen; - - if (threadIdx.x == 0) { - - extern __shared__ unsigned char shmem[]; - shMem = reinterpret_cast(shmem); - tadLen = shape::length(zTadShapeInfo); // 
xTadLen = yTadLen = zTadLen - } - __syncthreads(); - - const auto* xTad = x + xTadOffsets[blockIdx.x]; - const auto* yTad = y + yTadOffsets[blockIdx.x]; - auto* zTad = z + zTadOffsets[blockIdx.x]; - - // *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** // - - Z norm = 0; - Z sumOfProd = 0; - - for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { - - const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - - shMem[2*threadIdx.x] = static_cast(xTad[xOffset] * xTad[xOffset]); // for norm - shMem[2*threadIdx.x + 1] = static_cast(xTad[xOffset] * yTad[yOffset]); // for input * gradO - - __syncthreads(); - - // accumulate sum per block - for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { - - if (threadIdx.x < activeThreads && i + activeThreads < tadLen) { - - shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)]; - shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1]; - } - __syncthreads(); - } - - norm += shMem[0]; - sumOfProd += shMem[1]; - } - - // *** SECOND STAGE - GRADIENT CALCULATION *** // - - norm = math::nd4j_sqrt(norm); - - for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) { - - const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo); - - if(norm > clipNormVal) { - - const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo); - - const Z factor1 = static_cast(1) / norm; // 1 / norm - const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm) - - zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]); - } - else { - zTad[zOffset] = yTad[yOffset]; - } - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, - const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets, - const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets, - void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets, - void* vreducBuff, const double clipNormVal) { - - if(xTadOffsets == nullptr) { // means whole array - clipByNormBPWholeArrCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast(clipNormVal)); - clipByNormBPCalcGradCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast(clipNormVal)); - } - else // means tads using - clipByNormBPTadsCuda<<>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast(clipNormVal)); -} -BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES); - -////////////////////////////////////////////////////////////////////////// -void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm) { - - PointersManager manager(context, "clipByNormBP"); - - const double clipNormVal = clipNorm.e(0); - - const auto xType = input.dataType(); - 
const auto zType = gradI.dataType(); - - const int threadsPerBlock = MAX_NUM_THREADS / 2; - const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128; - - NDArray::prepareSpecialUse({&gradI}, {&input, &gradO}); - - - if(dimensions.empty() || dimensions.size() == input.rankOf()) { // means whole array - - const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), nullptr, gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES); - } - else { // means tads using - - auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.shapeInfo(), dimensions); - auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), dimensions); - - const int blocksPerGrid = packX.numberOfTads(); - BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES); - } - - NDArray::registerSpecialUse({&gradI}, {&input, &gradO}); - - manager.synchronize(); -} - template static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) { auto tid = blockIdx.x * blockDim.x; @@ -692,252 +429,6 @@ void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArra output.setIdentity(); } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) { - for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { - __shared__ T* z; - __shared__ Nd4jLong len; - if (threadIdx.x == 0) { - len = shape::length(shape); - z = inputBuffer + inputOffsets[arr]; - } - __syncthreads(); - for (int j = threadIdx.x; j < len; j+= blockDim.x) { - auto xIndex = shape::getIndexOffset(j, shape); - - if(norm2Buf[arr] > clipNorm) - z[xIndex] *= clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c' - } - } - } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static __global__ void clipByNormKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) { - - for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) { - __shared__ T* x, *z; - __shared__ Nd4jLong lenZ; - __shared__ T norm2; - - if (threadIdx.x == 0) { - x = inputBuffer + inputOffsets[arr]; - z = outputBuffer + outputOffsets[arr]; - lenZ = shape::length(outputShape); - norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)]; - } - __syncthreads(); - for (Nd4jLong j = 
threadIdx.x; j < lenZ; j+= blockDim.x) { - auto xIndex = shape::getIndexOffset(j, shape); - auto zIndex = shape::getIndexOffset(j, outputShape); - if(norm2 > clipNorm) { - z[zIndex] = x[xIndex] * clipNorm / norm2; // case with ews = 1 and ordering is 'c' - } else { - z[zIndex] = x[xIndex]; - } - //printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]); - } - __syncthreads(); - } - } - - ////////////////////////////////////////////////////////////////////////// - template - static void clipByNorm_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, NDArray const& clipNormA, const bool isInplace) { - const int rank = input.rankOf(); - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions); - clipNormA.syncToHost(); - //norm2.printBuffer("Norm2"); - T const clipNorm = clipNormA.e(0); - //clipNormA.printBuffer("ClipNorm"); - auto stream = context->getCudaStream(); - if (isInplace) { - if(norm2.lengthOf() == 1) { - norm2.syncToHost(); - T norm2Val = norm2.e(0); - if(norm2Val > clipNorm) - input *= clipNorm / norm2Val; - } - else { - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - //auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimsToExclude); - T* inputBuffer = reinterpret_cast(input.specialBuffer()); - T* norm2buf = reinterpret_cast(norm2.specialBuffer()); - - clipByNormInplaceKernel<<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); - } - } - else { - - if(norm2.lengthOf() == 1) { - norm2.syncToHost(); - T norm2Val = norm2.e(0); - - if(norm2Val > clipNorm) - output.assign( input * (clipNorm / norm2Val)); - else - output.assign( input ); - } - else { - - std::vector dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude); - auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); - auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimensions); - T* inputBuffer = reinterpret_cast(input.specialBuffer()); - T* norm2buf = reinterpret_cast(norm2.specialBuffer()); - T* outputBuffer = reinterpret_cast(output.specialBuffer()); - - clipByNormKernel<<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm); - } - } - } - - void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - - template - void clipByGlobalNorm_(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - NDArray globalNorm = NDArrayFactory::create(0, 
inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) - - for (auto i = 0; i < inputs.size(); i++) { - auto input = inputs[i]; - auto l2norm = input->reduceNumber(reduce::Norm2); - globalNorm += l2norm * l2norm; - } - - globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm); - outputs[inputs.size()]->p(0, globalNorm); - globalNorm.syncToHost(); - const T factor = static_cast(clipNorm) / globalNorm.e(0); - - for (size_t e = 0; e < inputs.size(); e++) { - // all-reduce - auto input = inputs[e]; - auto output = outputs[e]; - - if (globalNorm.e(0) <= clipNorm) { - output->assign(input); - } - else { - - auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - input->applyLambda(lambda, *output); - } - } - } - - void clipByGlobalNorm(sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { - BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace), FLOAT_TYPES); - - - ////////////////////////////////////////////////////////////////////////// - template - static void clipByAveraged_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - auto cn = clipNorm.e(0); - if (dimensions.size() == 0) { - // all-reduce - T n2 = input.reduceNumber(reduce::Norm2).e(0) / static_cast(input.lengthOf()); - if (n2 <= cn) { - if (!isInplace) - output.assign(input); - } - else { - const T factor = cn / n2; - //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; }; - //input.applyLambda(lambda, output); - output.assign(input * factor); - } - } - else { - // along dimension - auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false); - if (!isInplace) - output.assign(input); - auto tads = output.allTensorsAlongDimension(dimensions); - auto outTads = output.allTensorsAlongDimension(dimensions); - // TODO: make this CUDA-compliant somehow - for (int e = 0; e < tads.size(); e++) { - T n2 = norm2.e(e) / static_cast(tads.at(e)->lengthOf()); - const T factor = cn / n2; - if (n2 > cn) { - //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;}; - tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e));//applyLambda(lambda, &output); - } - } - } - } - - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES); - -/* - if (d1 > params[1]) - return params[1]; - else if (d1 < params[0]) - return params[0]; - else return d1; -*/ - template - static void __global__ clipByValueKernel(void* input, Nd4jLong const* inputShape, void* output, Nd4jLong const* outputShape, double leftBound, double rightBound) { - __shared__ T* outputBuf; - __shared__ T* inputBuf; - __shared__ Nd4jLong length; - __shared__ bool linearBuffers; - if 
(threadIdx.x == 0) { - outputBuf = reinterpret_cast(output); - inputBuf = reinterpret_cast(input); - length = shape::length(inputShape); - linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1; - } - __syncthreads(); - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - const auto step = gridDim.x * blockDim.x; - - for (Nd4jLong e = tid; e < length; e += step) { - if (linearBuffers) { - if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound; - else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound; - else outputBuf[e] = inputBuf[e]; - } - else { - auto inputOffset = shape::getIndexOffset(e, inputShape); - auto outputOffset = shape::getIndexOffset(e, outputShape); - if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound; - else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound; - else outputBuf[outputOffset] = inputBuf[outputOffset]; - } - } - } - - template - static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - auto stream = context->getCudaStream(); - if (!input.isActualOnDeviceSide()) - input.syncToDevice(); - NDArray::prepareSpecialUse({&output}, {&input}); - clipByValueKernel<<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound); - NDArray::registerSpecialUse({&output}, {&input}); - } - - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) { - BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES); - } } } diff --git a/libnd4j/include/ops/declarable/helpers/reverse.h b/libnd4j/include/ops/declarable/helpers/reverse.h index d85d017ba..c44327bb0 100644 --- a/libnd4j/include/ops/declarable/helpers/reverse.h +++ b/libnd4j/include/ops/declarable/helpers/reverse.h @@ -29,9 +29,9 @@ namespace helpers { void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim); - void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs, bool isBackProp); + void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector* intArgs); + - } } diff --git a/libnd4j/include/ops/declarable/helpers/transforms.h b/libnd4j/include/ops/declarable/helpers/transforms.h index 6ebecd8f7..d20b98e6c 100644 --- a/libnd4j/include/ops/declarable/helpers/transforms.h +++ b/libnd4j/include/ops/declarable/helpers/transforms.h @@ -63,13 +63,13 @@ namespace helpers { void mergeAdd(sd::LaunchContext * context, const std::vector& inArrs, NDArray& output); void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector& outArrs); - void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); + void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage); + void clipByGlobalNorm(sd::LaunchContext * context, std::vector 
const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector& outputs, bool isInplace); - void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm); + void clipByNormBp(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector& dimensions, const NDArray& clipNorm, const bool useAverage); - void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); - void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output); + void clipByAveragedNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector& dimensions, const NDArray& clipNorm, const bool isInplace); void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode); diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index c839c41c9..713a02666 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -1093,7 +1093,7 @@ namespace sd { return ND4J_STATUS_OK; NDArray *a0 = block.array(0); - for (int e = 0; e < block.width(); e++) { + for (int e = 1; e < block.width(); e++) { auto aV = block.array(e); if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo())) return ND4J_STATUS_BAD_DIMENSIONS; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 21bdbbe8d..6e0b1685a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -90,13 +90,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); + mkldnnUtils::setBlockStrides(*x, x_user_md); - mkldnnUtils::setBlockStrides(x, x_user_md); // z, output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(z, z_user_md); + mkldnnUtils::setBlockStrides(*z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -112,15 +111,10 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem; - if (zReorder) - dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem); - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // mean auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); @@ -141,8 +135,8 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); @@ -151,7 +145,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray ////////////////////////////////////////////////////////////////////////// -static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, +static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights, NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x @@ -206,20 +200,17 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(x, x_user_md); + mkldnnUtils::setBlockStrides(*x, x_user_md); // dLdO dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md); + mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md); // dLdI dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - - mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md); + mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -239,10 +230,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdO - mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // mean auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast(mean->buffer())); @@ -253,10 +244,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const args[DNNL_ARG_VARIANCE] = var_mkl_mem; // dLdI - auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer()); - const bool dLdIReorder = 
op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc(); - auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem; - args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem; + auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gamma and beta (and their gradients) if they are present if(weights != nullptr) { @@ -272,8 +260,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args); // reorder outputs if necessary - if (dLdIReorder) - dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem); + if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem); stream.wait(); @@ -662,9 +650,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); if (shape::strideDescendingCAscendingF(dLdO->shapeInfo())) - batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); + batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW); else - batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); + batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW); *dLdM = 0; *dLdV = 0; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp new file mode 100644 index 000000000..9df63556e --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -0,0 +1,186 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
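Note on the helper used throughout these mkldnn changes: loadDataToMklStream replaces the hand-written wrap-and-reorder blocks that the old code repeated for every tensor. The snippet below is only a sketch of the assumed semantics (names and signature are illustrative, not the actual mkldnnUtils API): wrap the NDArray buffer in a dnnl::memory, reorder into the layout the primitive prefers when the descriptors differ, expose that memory as the primitive argument, and hand back the user memory so results can later be reordered into it.

#include <dnnl.hpp>

// Hypothetical sketch only; the real helper lives in mkldnnUtils.cpp.
static dnnl::memory loadToStreamSketch(void* userBuf,
                                       const dnnl::memory::desc& userMd,
                                       const dnnl::memory::desc& primMd,
                                       const dnnl::engine& engine,
                                       dnnl::stream& stream,
                                       dnnl::memory& primArg) {
    dnnl::memory userMem(userMd, engine, userBuf);            // wrap the NDArray buffer
    const bool needReorder = primMd != userMem.get_desc();    // does the primitive want another layout?
    dnnl::memory primMem = needReorder ? dnnl::memory(primMd, engine) : userMem;
    if (needReorder)
        dnnl::reorder(userMem, primMem).execute(stream, userMem, primMem);
    primArg = primMem;   // memory handed to the primitive via args[...]
    return userMem;      // kept so outputs can be reordered back later
}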
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include +#include + +#include +#include "mkldnnUtils.h" +#include + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +static void concatMKLDNN(const std::vector& inArrs, NDArray& output, const int axis) { + + // data type + dnnl::memory::data_type type; + if(output.dataType() == DataType::FLOAT32) + type = dnnl::memory::data_type::f32; + else if(output.dataType() == DataType::HALF) + type = dnnl::memory::data_type::f16; + else if(output.dataType() == DataType::BFLOAT16) + type = dnnl::memory::data_type::bf16; + else if(output.dataType() == DataType::UINT8) + type = dnnl::memory::data_type::u8; + else + type = dnnl::memory::data_type::s8; + + std::vector x_user_md(inArrs.size()), x_mkl_md(inArrs.size()); + + // inputs + for (int i = 0; i < inArrs.size(); ++i) { + + dnnl::memory::dims dims = inArrs[i]->getShapeAsFlatVector(); + x_user_md[i] = x_mkl_md[i] = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(*inArrs[i])); + mkldnnUtils::setBlockStrides(*inArrs[i], x_user_md[i]); + } + + // output + dnnl::memory::dims dims = output.getShapeAsFlatVector(); + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(output)); + mkldnnUtils::setBlockStrides(output, z_user_md); + + std::unordered_map args; + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + dnnl::concat::primitive_desc op_prim_desc(axis, x_mkl_md, engine); + + dnnl::stream stream(engine); + + // inputs + for (int i = 0; i < inArrs.size(); ++i) + mkldnnUtils::loadDataToMklStream(*inArrs[i], engine, stream, x_user_md[i], op_prim_desc.src_desc(i), args[DNNL_ARG_MULTIPLE_SRC + i]); + + // outputs + auto z_user_mem = mkldnnUtils::loadDataToMklStream(output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); + + // primitive execution + dnnl::concat(op_prim_desc).execute(stream, args); + + // reorder output if necessary + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(concat, ENGINE_CPU) { + + REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided"); + + const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + // first of all take into account possible presence of empty arrays + // also if scalar is present -> copy its value to vector with length=1 + std::vector nonEmptyArrs; + std::vector arrsToDelete; + int index = 0; + bool allOfSameType = true; + auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0; + auto typeOfFirstArr = block.width() > 0 ? 
INPUT_VARIABLE(0)->dataType() : block.dataType(); + + for(int i = 0; i < numOfInArrs; ++i) { + auto input = INPUT_VARIABLE(i); + auto currentRank = input->rankOf(); + + if(!input->isEmpty()) { + + allOfSameType &= (typeOfFirstArr == input->dataType()); + + if(input->rankOf() == 0) { + auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext()); + vec->assign(input); + nonEmptyArrs.push_back(vec); + arrsToDelete.push_back(index); + } + else{ + nonEmptyArrs.push_back(input); + } + ++index; + } + } + + const int numOfNonEmptyArrs = nonEmptyArrs.size(); + + if(numOfNonEmptyArrs == 0){ + //All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op) + REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT MKLDNN op: If all input variables are empty, output must be empty"); + return Status::OK(); + } + + const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array + int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e(0) : INT_ARG(0); + if(axis < 0){ + axis += rank; + } + + // ******** input validation ******** // + REQUIRE_TRUE(allOfSameType, 0, "CONCAT MKLDNN op: all of input arrays must have same type !"); + REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT MKLDNN op: output array should have the same type as inputs arrays !"); + REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); + + for(int i = 1; i < numOfNonEmptyArrs; ++i) + REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT MKLDNN op: all input arrays must have the same rank !"); + + for(int i = 1; i < numOfNonEmptyArrs; ++i) { + for(int dim = 0; dim < rank; ++dim) + if(dim != axis) + REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) !"); + } + // ******** end of input validation ******** // + + auto output = OUTPUT_VARIABLE(0); + + if(numOfNonEmptyArrs == 1) + output->assign(nonEmptyArrs[0]); + else + concatMKLDNN(nonEmptyArrs, *output, axis); + + // delete dynamically allocated vectors with length=1 + for(int index : arrsToDelete) + delete nonEmptyArrs[index]; + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(concat, ENGINE_CPU) { + + auto z = OUTPUT_VARIABLE(0); + + const auto zType = z->dataType(); + + return z->rankOf() < 7 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index b1def8ed7..a889d0302 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -62,33 +62,23 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, auto type = dnnl::memory::data_type::f32; + std::vector permut; + if(0 == wFormat) + permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + else if(2 == wFormat) + permut = {0,3,1,2}; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] + // memory descriptors for arrays // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - 
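The permut vectors introduced above encode how the libnd4j weight layout maps onto oneDNN's canonical [oC, iC, kH, kW] order; setBlockStrides presumably uses them to pick the matching NDArray strides instead of the removed hand-written assignments. A minimal, self-contained illustration of that remapping (the shape values are made up for the example):

#include <array>
#include <cstdio>

// Sketch: strides of a 'c'-ordered [kH, kW, iC, oC] weight array (wFormat == 0)
// remapped to the [oC, iC, kH, kW] view via permut = {3,2,0,1}; permut[i] names
// the source axis that supplies destination axis i.
int main() {
    const std::array<long long, 4> srcShape   = {3, 3, 8, 16};      // kH, kW, iC, oC
    const std::array<long long, 4> srcStrides = {384, 128, 16, 1};  // 'c' order strides
    const std::array<int, 4>       permut     = {3, 2, 0, 1};       // -> oC, iC, kH, kW

    for (int i = 0; i < 4; ++i)
        std::printf("axis %d: dim %lld stride %lld\n",
                    i, srcShape[permut[i]], srcStrides[permut[i]]);
    // prints dims 16 8 3 3 with strides 1 16 384 128
    return 0;
}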
mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - } + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // bias dnnl::memory::desc b_mkl_md; @@ -98,7 +88,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); + mkldnnUtils::setBlockStrides(*output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -114,10 +104,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -126,17 +116,14 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, } // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::convolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); // shape::printArray(z_mkl_mem.map_data(),8); @@ -170,64 +157,38 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N auto type = dnnl::memory::data_type::f32; + std::vector permut; + if(0 == wFormat) + permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + else if(2 == wFormat) + permut = {0,3,1,2}; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] + // memory descriptors for arrays // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - } + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; - } - else { - i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW] - } - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - 
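For orientation while reading conv2dBpMKLDNN: gradI is produced by the backward-data primitive and gradW/gradB by the backward-weights primitive, so the two primitives read overlapping but distinct argument slots. The mapping below simply restates which args[] entries each one consumes or fills; the slot constants are the real DNNL macros, the strings are illustrative.

#include <dnnl.hpp>
#include <map>
#include <string>

// convolution_backward_data: reads weights and gradO, writes gradI
static const std::map<int, std::string> backwardDataArgs = {
    {DNNL_ARG_WEIGHTS,  "weights (read)"},
    {DNNL_ARG_DIFF_DST, "gradO   (read)"},
    {DNNL_ARG_DIFF_SRC, "gradI   (written)"}};

// convolution_backward_weights: reads input and gradO, writes gradW (and gradB)
static const std::map<int, std::string> backwardWeightsArgs = {
    {DNNL_ARG_SRC,          "input (read)"},
    {DNNL_ARG_DIFF_DST,     "gradO (read)"},
    {DNNL_ARG_DIFF_WEIGHTS, "gradW (written)"},
    {DNNL_ARG_DIFF_BIAS,    "gradB (written, when bias is present)"}};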
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - } + mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); // gradB dnnl::memory::desc gradB_mkl_md; @@ -256,10 +217,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); @@ -274,16 +235,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? 
dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); // gradB if(gradB != nullptr) { @@ -301,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index b9fa696c5..bfa9e49d1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -63,6 +63,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + std::vector permut; + if(0 == wFormat) + permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + else if(2 == wFormat) + permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] + auto type = dnnl::memory::data_type::f32; // memory descriptors for arrays @@ -70,29 +76,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); - } + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // bias dnnl::memory::desc b_mkl_md; @@ -102,7 +91,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // output dnnl::memory::desc 
z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); + mkldnnUtils::setBlockStrides(*output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -118,10 +107,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -130,17 +119,14 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, } // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::convolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } @@ -177,68 +163,40 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N auto type = dnnl::memory::data_type::f32; + std::vector permut; + if(0 == wFormat) + permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + else if(2 == wFormat) + permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] + // memory descriptors for arrays // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) { - w_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - 
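Every forward and backward kernel touched by this patch now finishes with the same epilogue: if the primitive wrote into a blocked layout, reorder back into the user's NDArray memory, then wait on the stream. A stand-alone sketch of that step, with illustrative names:

#include <dnnl.hpp>
#include <unordered_map>

// Sketch of the "reorder outputs if necessary" epilogue: 'userMem' wraps the
// NDArray buffer, 'args' holds the memory the primitive actually wrote to
// (possibly in a layout it chose via format_tag::any).
static void reorderBackIfNeeded(std::unordered_map<int, dnnl::memory>& args,
                                int argSlot,                       // e.g. DNNL_ARG_DST or DNNL_ARG_DIFF_SRC
                                const dnnl::memory::desc& primDesc,
                                dnnl::memory& userMem,
                                dnnl::stream& stream) {
    if (primDesc != userMem.get_desc())
        dnnl::reorder(args[argSlot], userMem).execute(stream, args[argSlot], userMem);
    stream.wait();   // results are now visible in the user buffer
}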
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); - } + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl); - if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) { - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4; - } - else { - i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW] - } - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); - } + mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); // gradB dnnl::memory::desc gradB_mkl_md; @@ -267,10 +225,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); @@ -285,16 +243,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? 
dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); // gradB if(gradB != nullptr) { @@ -312,10 +264,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index 584fd50a5..b7b58b409 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -47,16 +47,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dH-1, dW-1 }; - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - } + std::vector permut; + if(0 == wFormat) + permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + else if(1 == wFormat) + permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + else + permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] // input type dnnl::memory::data_type xType; @@ -99,16 +96,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // bias dnnl::memory::desc b_mkl_md; @@ -118,7 +111,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, 
dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); + mkldnnUtils::setBlockStrides(*output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -135,10 +128,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -147,17 +140,14 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N } // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); @@ -180,16 +170,13 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dH-1, dW-1 }; - uint i0, i1, i2, i3; - if(0 == wFormat) { - i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] - } - else { - i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] - } + std::vector permut; + if(0 == wFormat) + permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + else if(1 == wFormat) + permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW] + else + permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW] // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? 
dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; @@ -216,35 +203,27 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); + mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); // gradB dnnl::memory::desc gradB_mkl_md; @@ -273,10 +252,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); @@ -291,16 +270,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? 
dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); // gradB if(gradB != nullptr) { @@ -318,10 +291,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 5e5da4748..10e3ba77e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -31,7 +31,7 @@ namespace ops { namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, +static void deconv2TFdBpMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const bool isNCHW, const int wFormat) { @@ -67,21 +67,17 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + mkldnnUtils::setBlockStrides(*weights, w_user_md, {3,2,0,1}); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = 
dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -101,23 +97,20 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // provide memory buffers and check whether reorder is required // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // run backward data calculations dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); stream.wait(); @@ -189,7 +182,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] // } - deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); + deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat); // delete weights; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index eb6966c77..59f355c6e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -48,16 +48,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - } - else { - i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // 
[iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - } + std::vector permut; + if(0 == wFormat) + permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + else if(1 == wFormat) + permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + else + permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] // input type dnnl::memory::data_type xType; @@ -100,17 +97,12 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // bias dnnl::memory::desc b_mkl_md; @@ -120,7 +112,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); + mkldnnUtils::setBlockStrides(*output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -137,10 +129,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -149,17 +141,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N } // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::deconvolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); @@ -185,16 +174,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 }; - uint i0, i1, i2, i3, i4; - if(0 == wFormat) { - i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - } - else if(1 == wFormat) { - i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] - } - else { - i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] - } + std::vector permut; + if(0 == wFormat) + permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + else if(1 == wFormat) + permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW] + else + permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW] // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; @@ -221,37 +207,27 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl); - w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0); - w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1); - w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2); - w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3); - w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4); + mkldnnUtils::setBlockStrides(*weights, w_user_md, permut); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, 
wFormatMkl); - gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0); - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1); - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2); - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3); - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4); + mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut); // gradB dnnl::memory::desc gradB_mkl_md; @@ -281,10 +257,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); @@ -299,16 +275,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? 
dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); // gradB if(gradB != nullptr) { @@ -326,10 +296,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index 92f40537b..938494d5a 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -28,7 +28,7 @@ using namespace dnnl; -namespace sd { +namespace sd { namespace ops { namespace platforms { @@ -109,7 +109,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -129,7 +129,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl); - mkldnnUtils::setBlockStrides(output, z_user_md); + mkldnnUtils::setBlockStrides(*output, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -146,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias if(bias != nullptr) { @@ -158,24 +158,21 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, } // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto 
z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::convolution_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); // shape::printArray(z_mkl_mem.map_data(),8); } ////////////////////////////////////////////////////////////////////////// -static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, +static void depthwiseConv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int paddingMode, const bool isNCHW, const int wFormat) { @@ -235,7 +232,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl); - mkldnnUtils::setBlockStrides(input, x_user_md); + mkldnnUtils::setBlockStrides(*input, x_user_md); // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -250,12 +247,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradO, gradO_user_md); + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl); - mkldnnUtils::setBlockStrides(gradI, gradI_user_md); + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md); // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); @@ -294,10 +291,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast(gradO->buffer())); @@ -312,16 +309,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI - auto gradI_user_mem = 
dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // gradW - auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer()); - const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); - auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; - args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]); // gradB if(gradB != nullptr) { @@ -339,10 +330,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); - if (gradWReorder) - dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); + if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem); stream.wait(); @@ -458,7 +449,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); + depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index 60c61ea5f..a74a55732 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -169,71 +169,43 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any); // x_user_md = dataFormat == 0 ? 
dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc); x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; + mkldnnUtils::setBlockStrides(*x, x_user_md); // wx wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any); wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - wx_user_md.data.format_kind = dnnl_blocked; // overrides format - wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0]; - wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1]; - wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2]; - wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3]; - wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4]; + mkldnnUtils::setBlockStrides(*Wx, wx_user_md); // wr wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any); wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo); - wr_user_md.data.format_kind = dnnl_blocked; // overrides format - wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0]; - wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1]; - wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2]; - wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3]; - wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4]; + mkldnnUtils::setBlockStrides(*Wr, wr_user_md); // h h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any); // h_user_md = dataFormat == 0 ? 
dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc); h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc); - h_user_md.data.format_kind = dnnl_blocked; // overrides format - h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0]; - h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1]; - h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2]; + mkldnnUtils::setBlockStrides(*h, h_user_md); // b if(b) { b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any); b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo); - b_user_md.data.format_kind = dnnl_blocked; // overrides format - b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0]; - b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1]; - b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2]; - b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3]; + mkldnnUtils::setBlockStrides(*b, b_user_md); } // hI if(hI) { hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - hI_user_md.data.format_kind = dnnl_blocked; // overrides format - hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0]; - hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1]; - hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2]; - hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3]; + mkldnnUtils::setBlockStrides(*hI, hI_user_md); } // cI if(cI) { cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any); cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc); - cI_user_md.data.format_kind = dnnl_blocked; // overrides format - cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0]; - cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1]; - cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2]; - cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3]; + mkldnnUtils::setBlockStrides(*cI, cI_user_md); } // hL @@ -241,20 +213,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any); hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); hL_user_md.data.format_kind = dnnl_blocked; // overrides format - hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0]; - hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1]; - hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2]; - hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3]; + mkldnnUtils::setBlockStrides(*hL, hL_user_md); } if(cL) { cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc); - cL_user_md.data.format_kind = dnnl_blocked; // overrides format - cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0]; - cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1]; - cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2]; - 
cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3]; + mkldnnUtils::setBlockStrides(*cL, cL_user_md); } // lstm memory description @@ -272,64 +237,49 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* // provide memory and check whether reorder is required // x - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]); // wx - mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); + mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]); // wr - mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); + mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]); // h - auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer()); - const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc(); - auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem; - args[DNNL_ARG_DST_LAYER] = h_lstm_mem; + auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER]); // b - if(b) { - mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); - } + if(b) + mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]); // hI - if(hI) { - mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); - } + if(hI) + mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]); // cI - if(cI) { - mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); - } + if(cI) + mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]); - bool hLReorder(false), cLReorder(false); dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem; // hL - if(hL) { - hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer()); - hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc(); - hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; - args[DNNL_ARG_DST_ITER] = hL_lstm_mem; - } + if(hL) + hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), args[DNNL_ARG_DST_ITER]); // cL - if(cL) { - cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer()); - cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); - cL_lstm_mem = cLReorder ? 
dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; - args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem; - } + if(cL) + cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), args[DNNL_ARG_DST_ITER_C]); // run calculations lstm_forward(lstm_prim_desc).execute(stream, args); // reorder outputs if necessary - if (hReorder) - reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem); - if(hLReorder) - reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem); - if(cLReorder) - reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem); + if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem); + if(lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem); + if(lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc()) + reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem); stream.wait(); } @@ -377,9 +327,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step // evaluate dimensions - const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); - const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); - const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong sL = x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); + const Nd4jLong nIn = x->sizeAt(2); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; // inputs validations @@ -435,14 +385,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut})); WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut})); + if(b) - bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); + bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); + else + bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullified + if(hI) hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); + if(cI) cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); + if(hL) hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false)); + if(cL) cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false)); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp index 265fb74bc..f242b2e79 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/matmul.cpp @@ -31,20 +31,6 @@ namespace sd { namespace ops { namespace platforms { - dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) { - switch (array.rankOf()) { - case 1: - return dnnl::memory::format_tag::ab; - case 2: - return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; - case 3: - return array.ordering() == 'c' ? 
dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba; - default: - throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays"); - } - } - - ////////////////////////////////////////////////////////////////////////// static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) { @@ -123,11 +109,16 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b else if(z->dataType() == DataType::INT8) zType = dnnl::memory::data_type::s8; + + const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR); + const auto yFormat = yRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*yTR); + const auto zFormat = zRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*zR); + // memory descriptors for arrays + dnnl::memory::desc x_mkl_md, x_user_md, y_mkl_md, y_user_md, z_mkl_md, z_user_md; // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR)); + x_user_md = x_mkl_md = dnnl::memory::desc(xShape, xType, xFormat); if(xTR->ews() != 1) { x_user_md.data.format_kind = dnnl_blocked; // overrides format x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0); @@ -137,8 +128,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b } // y - dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); - dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR)); + y_user_md = y_mkl_md = dnnl::memory::desc(yShape, yType, yFormat); if(yTR->ews() != 1) { y_user_md.data.format_kind = dnnl_blocked; // overrides format y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0); @@ -148,8 +138,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b } // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR)); + z_user_md = z_mkl_md = dnnl::memory::desc(zShape, zType, zFormat); if(zR->ews() != 1) { z_user_md.data.format_kind = dnnl_blocked; // overrides format z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0); @@ -181,37 +170,20 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); - /* - auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer()); - const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); - auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; - if (xReorder) - dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); - args[DNNL_ARG_SRC] = x_mkl_mem; -*/ + mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + // y - mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); - /* - auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer()); - const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc(); - auto y_mkl_mem = yReorder ? 
dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem; - if (yReorder) - dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem); - args[DNNL_ARG_WEIGHTS] = y_mkl_mem; -*/ + mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + // z - auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::matmul(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index bc79e6169..dcc0258f4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -38,45 +38,65 @@ void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){ mklDims = dnnl::memory::dims(vDims); } ////////////////////////////////////////////////////////////////////// -dnnl::memory::format_tag getFormat(const int rank){ - if (2 == rank) { - return dnnl::memory::format_tag::ab; - } - else if (3 == rank) { - return dnnl::memory::format_tag::abc; - } - else if (4 == rank) { - return dnnl::memory::format_tag::abcd; - } - else if (5 == rank) { - return dnnl::memory::format_tag::abcde; - } - else if (6 == rank) { - return dnnl::memory::format_tag::abcdef; - } - return dnnl::memory::format_tag::a; // 1 == dataSetRank +dnnl::memory::format_tag getFormat(const NDArray& arr) { + + dnnl::memory::format_tag result; + + switch (arr.rankOf()) { + case 1: + result = dnnl::memory::format_tag::a; + break; + case 2: + result = arr.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba; + break; + case 3: + result = arr.ordering() == 'c' ? 
dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba; + break; + case 4: + result = dnnl::memory::format_tag::abcd; + break; + case 5: + result = dnnl::memory::format_tag::abcde; + break; + case 6: + result = dnnl::memory::format_tag::abcdef; + break; + default: + throw std::invalid_argument("MKLDNN getFormat: do we really want to use arrays with rank > 6 ?"); + } + + return result; } ////////////////////////////////////////////////////////////////////// -void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){ +void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut) { - if (array->ews() != 1 || array->ordering() != 'c') { - mklMd.data.format_kind = dnnl_blocked; // overrides format - for (auto i = 0; i < array->rankOf(); ++i) { - mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i); + if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) { + + mklMd.data.format_kind = dnnl_blocked; // overrides format + + if(permut.empty()) + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i); + else { + if(array.rankOf() != permut.size()) + throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !"); + for (auto i = 0; i < array.rankOf(); ++i) + mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]); } } } //////////////////////////////////////////////////////////////////////////////////////////////// -void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, - dnnl::memory& arg) { +dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, + const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg) { - auto user_mem = dnnl::memory(user_md, engine,const_cast(array->buffer())); + auto user_mem = dnnl::memory(user_md, engine, const_cast(array).buffer()); const bool bReorder = primitive_md != user_mem.get_desc(); auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem; if (bReorder) dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem); arg = mkl_mem; + return user_mem; } ////////////////////////////////////////////////////////////////////// @@ -122,33 +142,21 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; } + std::vector permut; + if(!isNCHW) + permut = rank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); + // memory descriptors for arrays // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 
4 : 3); - } + mkldnnUtils::setBlockStrides(*input, x_user_md, permut); // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(output->ews() != 1 || output->ordering() != 'c') { - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3); - } + mkldnnUtils::setBlockStrides(*output, z_user_md, permut); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -164,20 +172,17 @@ void poolingMKLDNN(const NDArray *input, NDArray *output, // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // output - auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::pooling_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } @@ -226,46 +231,27 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; } + std::vector permut; + if(!isNCHW) + permut = rank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); + + // memory descriptors for arrays // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(input->ews() != 1 || input->ordering() != 'c') { - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 
4 : 3); - } + mkldnnUtils::setBlockStrides(*input, x_user_md, permut); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); - if(gradO->ews() != 1 || gradO->ordering() != 'c') { - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3); - } + mkldnnUtils::setBlockStrides(*gradO, gradO_user_md, permut); // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); - if(gradI->ews() != 1 || gradI->ordering() != 'c') { - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2); - if(rank == 5) - gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3); - } + mkldnnUtils::setBlockStrides(*gradI, gradI_user_md, permut); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); @@ -282,18 +268,15 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, std::unordered_map args; // gradO - mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // gradI - auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer()); - const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); - auto gradI_mkl_mem = gradIReorder ? 
dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; - args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); if(mode == algorithm::pooling_max) { // input - mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); @@ -310,10 +293,9 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, // run backward calculations dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); - // reorder gradI if necessary - if (gradIReorder) - dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem); stream.wait(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 514a325c7..f3ff327a4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -100,6 +100,8 @@ namespace sd { DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU); + DECLARE_PLATFORM(concat, ENGINE_CPU); + } } @@ -123,19 +125,13 @@ namespace sd { */ void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims); /** - * This function generate memory format tag based on rank - * @param const array rank + * This function evaluate memory format tag based on array shapeInfo + * @param const array * @return memory format */ - dnnl::memory::format_tag getFormat(const int rank); - /** - * This function generate memory format tag based on rank - * @param const pointer to dataset - * @param const dataset rank - * @param reference to memory descriptor - * @return memory format - */ - void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd); + dnnl::memory::format_tag getFormat(const NDArray& arr); + + void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector& permut = {}); ////////////////////////////////////////////////////////////////////// /** * This function load and reorder user memory to mkl @@ -147,7 +143,7 @@ namespace sd { * @param primitive memory descriptor * @param dnnl arg activation enumerator */ - void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, + dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg); /** diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp index 932affbd3..9935fd50f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp @@ -35,32 +35,37 @@ namespace sd { ////////////////////////////////////////////////////////////////////// static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { - const auto xRank 
= x->rankOf(); - dnnl::memory::dims xShape, zShape; + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(z, xRank, zShape); + const int xRank = x->rankOf(); + dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x); + dnnl::memory::format_tag zFormat = mkldnnUtils::getFormat(*z); - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); // optimized cases if (2 == xRank && 0 == axis) { - format = dnnl::memory::format_tag::ba; + if(x->ews() == 1) + xFormat = dnnl::memory::format_tag::ba; + if(z->ews() == 1) + zFormat = dnnl::memory::format_tag::ba; } else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) { - format = dnnl::memory::format_tag::acdb; + if(x->ews() == 1) + xFormat = dnnl::memory::format_tag::acdb; + if(z->ews() == 1) + zFormat = dnnl::memory::format_tag::acdb; } dnnl::memory::data_type xType = dnnl::memory::data_type::f32; - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md; + + x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat); + mkldnnUtils::setBlockStrides(*x, x_user_md); // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format); - mkldnnUtils::setBlockStrides(z, z_user_md); + z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat); + mkldnnUtils::setBlockStrides(*z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -80,20 +85,17 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::softmax_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } @@ -142,33 +144,19 @@ namespace sd { ////////////////////////////////////////////////////////////////////// static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) { - const auto xRank = x->rankOf(); - const auto dLdzRank = dLdz->rankOf(); - - dnnl::memory::dims xShape, dLdxShape, dLdzShape; - - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(dLdx, xRank, dLdxShape); - mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape); - - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); + dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; // x - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + x_mkl_md = x_user_md = dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); - // todo if mkl does not support broadcast we can remove this - format = mkldnnUtils::getFormat(dLdzRank); + dLdx_mkl_md = dLdx_user_md = dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + dLdz_mkl_md = dLdz_user_md = dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -188,19 +176,18 @@ namespace sd { // provide memory buffers and check whether reorder is required for forward // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]); + + // dLdz + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? 
dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem; - argsff[DNNL_ARG_DST] = dLdx_mkl_mem; + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_DST]); // check and arg set for backprob - argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; - argsbp[DNNL_ARG_DST] = dLdx_mkl_mem; - // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]); + argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST]; + argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST]; + // run calculations forward dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff); @@ -209,8 +196,8 @@ namespace sd { dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp); // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem).execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem); stream.wait(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp index 53d75d0a9..a808239de 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/tanh.cpp @@ -34,22 +34,16 @@ namespace sd { ////////////////////////////////////////////////////////////////////// static void tanhMKLDNN(const NDArray* x, NDArray* z) { - const auto xRank = x->rankOf(); - dnnl::memory::dims xShape, zShape; + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(z, xRank, zShape); + dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md; - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); // z - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(z, z_user_md); + z_user_md = z_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z)); + mkldnnUtils::setBlockStrides(*z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -68,20 +62,17 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::eltwise_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } @@ -121,28 +112,21 @@ namespace sd { ////////////////////////////////////////////////////////////////////// static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) { - const auto xRank = x->rankOf(); - dnnl::memory::dims xShape, dLdzShape, dLdxShape; + dnnl::memory::dims shape = x->getShapeAsFlatVector(); - mkldnnUtils::getDims(x, xRank, xShape); - mkldnnUtils::getDims(dLdz, xRank, dLdzShape); - mkldnnUtils::getDims(dLdx, xRank, dLdxShape); + dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md; - dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank); - - dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + // x + x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); // dLdz - dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + dLdz_user_md = dLdz_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); // dLdx - dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + dLdx_user_md = dLdx_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -162,23 +146,20 @@ namespace sd { // provide memory buffers and check whether reorder is required for forward // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // dLdz - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]); // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? 
dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; - args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]); // run calculations backward dnnl::eltwise_backward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem); stream.wait(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp index ab7f340ed..1097ccd34 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/xw_plus_b.cpp @@ -82,33 +82,23 @@ namespace sd { // memory descriptors for arrays // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); // weights dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format); - if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights)); + mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? 
std::vector({1,0}) : std::vector()); - weights_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); - } - else { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); - } - } // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(bias, bias_user_md); + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a); + mkldnnUtils::setBlockStrides(*bias, bias_user_md); // z dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); - dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); - mkldnnUtils::setBlockStrides(z, z_user_md); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z)); + mkldnnUtils::setBlockStrides(*z, z_user_md); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -125,27 +115,24 @@ namespace sd { // provide memory buffers and check whether reorder is required // input - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]); // weights - mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]); // bias auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast(bias->buffer())); args[DNNL_ARG_BIAS] = bias_mkl_mem; // z - auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer()); - const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); - auto z_mkl_mem = zReorder ? 
dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; - args[DNNL_ARG_DST] = z_mkl_mem; + auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]); // run calculations dnnl::inner_product_forward(op_prim_desc).execute(stream, args); // reorder outputs if necessary - if (zReorder) - dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + if (op_prim_desc.dst_desc() != z_user_mem.get_desc()) + dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem); stream.wait(); } @@ -160,7 +147,7 @@ namespace sd { // [M,K] x [K,N] = [M,N] const int M = x->sizeAt(0); - const int K = x->sizeAt(1); // K == wK + const int K = x->sizeAt(1); // K == wK const int N = dLdz->sizeAt(1); // input dims dnnl::memory::dims xShape = dnnl::memory::dims({ M, K }); @@ -168,71 +155,53 @@ namespace sd { dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N }); dnnl::memory::dims bShape = dnnl::memory::dims({ N }); + // output dims dnnl::memory::dims dLdxShape = xShape; dnnl::memory::dims dLdwShape = wShape; - dnnl::memory::format_tag format = dnnl::memory::format_tag::ab; dnnl::memory::data_type dataType = dnnl::memory::data_type::f32; // memory descriptors for arrays // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format); - mkldnnUtils::setBlockStrides(x, x_user_md); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x)); + mkldnnUtils::setBlockStrides(*x, x_user_md); // weights dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format); - if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) { + dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights)); + mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? 
std::vector({1,0}) : std::vector()); - weights_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0); - } - else { - weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0); - weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1); - } - } // bias - dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(bias, bias_user_md); + dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias)); + mkldnnUtils::setBlockStrides(*bias, bias_user_md); // dLdz dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format); - mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md); + dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz)); + mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md); + // dLdw - dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format); - dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format); - if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) { - - dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format - if (bShouldTransp) { - dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1); - dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0); - } - else { - dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0); - dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1); - } - } + dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw)); + mkldnnUtils::setBlockStrides(*dLdw, dLdw_user_md, bShouldTransp ? 
std::vector({1,0}) : std::vector()); // dLdb - dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x); - mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md); + dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any); + dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb)); + mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md); // dLdx dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any); - dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format); - mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md); + dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx)); + mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md); + // create engine auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + // forward // operation primitive description dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md); @@ -254,34 +223,25 @@ namespace sd { dnnl::stream stream(engine); // dLdz dw - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]); // dLdz - dx - mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); + mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]); // input x for dw - mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); + mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]); // weights - dx - mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); + mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]); - // dLdw - auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer()); - const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc(); - auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem; - argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem; + // dLdw + auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, op_bpdw_prim_desc.diff_weights_desc(), argsDw[DNNL_ARG_DIFF_WEIGHTS]); - // dLdx - auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer()); - const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc(); - auto dLdx_mkl_mem = dLdxReorder ? 
dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem; - argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem; + // dLdx + auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_bpdx_prim_desc.diff_src_desc(), argsDx[DNNL_ARG_DIFF_SRC]); // dLdb - auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer()); - const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc(); - auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem; - argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem; + auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, op_bpdw_prim_desc.diff_bias_desc(), argsDw[DNNL_ARG_DIFF_BIAS]); // run calculations dw dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw); @@ -289,14 +249,14 @@ namespace sd { dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx); // reorder outputs if necessary - if (dLdxReorder) - dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem); + if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc()) + dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem); - if (dLdwReorder) - dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem); + if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc()) + dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem); - if (dLdbReorder) - dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem); + if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc()) + dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem); stream.wait(); } @@ -315,7 +275,7 @@ namespace sd { const int wRank = w->rankOf(); const int zRank = z->rankOf(); - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank); REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank); @@ -378,7 +338,7 @@ namespace sd { const int wRank = w->rankOf(); const int dLdzRank = dLdz->rankOf(); - const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] + const bool bShouldTransp = block.getIArguments()->size() > 0 ? 
(1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N] REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf()); REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf()); diff --git a/libnd4j/include/ops/impl/specials_single.hpp b/libnd4j/include/ops/impl/specials_single.hpp index 9a700251c..b6d717b83 100644 --- a/libnd4j/include/ops/impl/specials_single.hpp +++ b/libnd4j/include/ops/impl/specials_single.hpp @@ -107,6 +107,25 @@ namespace sd { // samediff::Threads::parallel_tad(func, 0, numOfArrs); // } +// static Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) { + +// Nd4jLong result = 9223372036854775807LL; + +// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) { + +// const auto currentStride = shape::stride(inShapeInfo)[i]; + +// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1) +// continue; + +// if(result > currentStride) +// result = currentStride; +// } + +// return result == 9223372036854775807LL ? 1 : result; +// } + + template void SpecialMethods::concatCpuGeneric(const std::vector& inArrs, NDArray& output, const int axis) { @@ -150,7 +169,7 @@ void SpecialMethods::concatCpuGeneric(const std::vector& inAr // if(!areInputsContin || !allSameOrder) // break; - // strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo()); + // strideOfContigStride[i] = strideOverContigAxis(axis, inArrs[i]->getShapeInfo()); // } // } @@ -158,7 +177,7 @@ void SpecialMethods::concatCpuGeneric(const std::vector& inAr // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array - // const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo()); + // const auto zStep = strideOverContigAxis(axis, output.getShapeInfo()); // for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index c909b7686..cbec08c0c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -184,7 +184,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) { double tArgs[] = { -1.0, 1.0, 0.01 }; auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); - shape::printShapeInfoLinear("Result", shapes->at(0)); + // shape::printShapeInfoLinear("Result", shapes->at(0)); ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); delete shapes; @@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { 0.928968489f, 0.684074104f }); - //get subarray + //get subarray //get subarray NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); @@ -627,7 +627,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { }); auto actual = NDArrayFactory::create('c', { 3 }); - //get subarray + //get subarray NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); subArrHsvs.reshapei({ 3 }); NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); @@ -635,7 +635,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { #if 0 //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] subArrHsvs.printShapeInfo("subArrHsvs"); 
-#endif +#endif Context ctx(1); ctx.setInputArray(0, &subArrHsvs); @@ -855,7 +855,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { -0.04447775f, -0.44518381f }); - //get subarray + //get subarray NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); subArrRgbs.reshapei({ 3 }); @@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { 0.280231822f, 1.91936605f }); - //get subarray + //get subarray NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); subArrYiqs.reshapei({ 3 }); @@ -1074,3 +1074,422 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_TRUE(expected.equalsTo(actual)); } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_1) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.0}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests16, clipbynorm_2) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {6.0}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_3) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto unities = NDArrayFactory::create('c', {3, 1}, {1., 1., 1.}); + auto scale = NDArrayFactory::create('c', {3, 1}, {1.1, 1., 0.9}); + + x.linspace(100.); + + auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + x /= xNorm1; + xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); + + ASSERT_TRUE(unities.isSameShape(xNorm1)); + ASSERT_TRUE(unities.equalsTo(xNorm1)); + + x *= scale; + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.0}, {1}); + auto z = result.at(0); + + auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); + auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); + + ASSERT_TRUE(exp.isSameShape(&zNorm1)); + ASSERT_TRUE(exp.equalsTo(&zNorm1)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_4) { + + auto x = NDArrayFactory::create('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_5) { + 
+ // auto x = NDArrayFactory::create('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5}); + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676}); + // auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_6) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_7) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_8) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_9) { + + auto x = NDArrayFactory::create('c', {2}, {3., 4.}); + auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_10) { + + auto x = NDArrayFactory::create(6.); + auto exp = NDArrayFactory::create(5.); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {5.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_11) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061, + 13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 
16.88123}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {35.}, {0, 2}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_12) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9}); + auto e = NDArrayFactory::create('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {0.54}, {}); + + ASSERT_EQ(e, *result.at(0)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_13) { + + const int bS = 5; + const int nOut = 4; + const int axis = 0; + const double clip = 2.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1] + auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); + auto expect = NDArrayFactory::create('c', {bS, nOut}); + + auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] + + auto y = ( (x / norm2) * clip) * colVect ; + auto temp = (x / norm2) * clip; + + for (int j = 0; j < nOut; ++j) { + auto yCol = y({0,0, j,j+1}); + const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); + if (norm2Col <= clip) + expect({0,0, j,j+1}).assign(yCol); + else + expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) ); + } + + sd::ops::clipbynorm op; + auto result = op.evaluate({&y}, {clip}, {axis}); + auto outFF = result.at(0); + + ASSERT_TRUE(expect.isSameShape(outFF)); + ASSERT_TRUE(expect.equalsTo(outFF)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_1) { + + const int bS = 2; + const int nOut = 3; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_2) { + + const int bS = 2; + const int nOut = 3; + const int axis = 0; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_3) { + + const int bS = 2; + const int nOut = 3; + const int 
axis = 1; + const double clip = 1.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); + + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.8}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_2) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); + + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.9}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + + + + + + + + + + + + + + + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_1) { + + const int bS = 2; + const int nOut = 3; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_2) { + + const int bS = 2; + const int nOut = 3; + const int axis = 1; + const double clip = 1.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_3) { + + NDArray x('c', {2, 3, 4}, {-0.14 ,0.96 ,0.47 ,-0.98 ,0.03 ,0.95 ,0.33 ,-0.97 ,0.59 ,-0.92 ,-0.12 ,-0.33 ,0.82 ,-0.76 ,-0.69 ,-0.95 ,-0.77 ,0.25 ,-0.35 ,0.94 ,0.50 ,0.04 ,0.61 ,0.99}, sd::DataType::DOUBLE); + NDArray gradO('c', {2, 3, 4}, sd::DataType::DOUBLE); + + const OpArgsHolder argsHolderFF({&x}, {0.7}, {0,2}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {0.7}, {0,2}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp 
opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index eddef73b3..38006dd50 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests3, Test_Permute_1) { @@ -123,7 +123,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } TEST_F(DeclarableOpsTests3, Test_Unique_2) { @@ -171,7 +171,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) { ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { ASSERT_TRUE(exp0.isSameShape(z0)); ASSERT_TRUE(exp0.equalsTo(z0)); - + auto result1 = op.evaluate({&x, &axis}, {1}, {}); @@ -244,94 +244,6 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) { } - -TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) { - auto x = NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp = NDArrayFactory::create('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); - - sd::ops::clipbyavgnorm op; - auto result = op.evaluate({&x}, {0.8}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); -} - -TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); - - sd::ops::clipbyavgnorm op; - auto result = op.evaluate({&x}, {0.9}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - -} - - -TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {4.0}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - -} - -TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) { - auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {6.0}, {}); - - auto z = result.at(0); - - ASSERT_TRUE(exp.isSameShape(z)); - ASSERT_TRUE(exp.equalsTo(z)); - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) { - - auto x = NDArrayFactory::create('c', {3, 5}); - auto unities = NDArrayFactory::create('c', {3, 1}, {1., 1., 1.}); - auto scale = NDArrayFactory::create('c', {3, 1}, {1.1, 1., 0.9}); - - x.linspace(100.); - - auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); - x /= xNorm1; - xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); - - ASSERT_TRUE(unities.isSameShape(xNorm1)); - ASSERT_TRUE(unities.equalsTo(xNorm1)); - - x *= scale; - xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); - - sd::ops::clipbynorm op; - auto result = 
op.evaluate({&x}, {1.0}, {1}); - auto z = result.at(0); - - auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); - auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); - - ASSERT_TRUE(exp.isSameShape(&zNorm1)); - ASSERT_TRUE(exp.equalsTo(&zNorm1)); - -} - TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { auto x= NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); auto y= NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); @@ -551,7 +463,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { @@ -579,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { @@ -607,7 +519,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { @@ -635,7 +547,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { @@ -663,7 +575,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { } delete exp; - + } @@ -692,7 +604,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { @@ -722,7 +634,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { } delete exp; - + } TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { @@ -734,7 +646,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { sd::ops::batched_gemm op; try { auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); - + ASSERT_TRUE(false); } catch (std::invalid_argument &e) { // @@ -875,7 +787,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) { ASSERT_TRUE(expCt.isSameShape(ct)); ASSERT_TRUE(expCt.equalsTo(ct)); - + } //////////////////////////////////////////////////////////////////// @@ -946,7 +858,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) { ASSERT_TRUE(expHt.isSameShape(ht)); ASSERT_TRUE(expHt.equalsTo(ht)); - + } //////////////////////////////////////////////////////////////////// @@ -1001,7 +913,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1021,7 +933,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1099,7 +1011,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete input; } @@ -1120,7 +1032,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + delete input; } /////////////////////////////////////////////////////////////////// @@ -1245,7 +1157,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1551,7 +1463,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output, 1e-6)); - + } /////////////////////////////////////////////////////////////////// @@ -1576,7 +1488,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) { 
ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1642,7 +1554,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1689,7 +1601,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1831,7 +1743,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -1856,7 +1768,7 @@ TEST_F(DeclarableOpsTests3, zeta_test9) { ASSERT_TRUE(expected.isSameShape(z)); ASSERT_TRUE(expected.equalsTo(z)); -// +// } /////////////////////////////////////////////////////////////////// @@ -1881,7 +1793,7 @@ TEST_F(DeclarableOpsTests3, zeta_test10) { ASSERT_TRUE(expected.isSameShape(z)); ASSERT_TRUE(expected.equalsTo(z)); -// +// } @@ -1908,7 +1820,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) { x.assign(0.5); auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); - + sd::ops::polygamma op; auto result = op.evaluate({&n, &x}, {}, {}); @@ -1920,7 +1832,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2263,7 +2175,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) { ASSERT_TRUE(expS.equalsTo(s)); ASSERT_TRUE(expS.isSameShape(s)); - + } /////////////////////////////////////////////////////////////////// @@ -2416,7 +2328,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) { // ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); // } -// +// // } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 26c6b5d53..04bb54a61 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -54,7 +54,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) { @@ -75,7 +75,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -97,7 +97,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) { @@ -118,7 +118,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) { @@ -139,7 +139,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) { @@ -160,7 +160,7 @@ TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, 
Test_TTS_bp_1) { @@ -183,7 +183,7 @@ TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) { ASSERT_TRUE(x.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -217,7 +217,7 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { auto result = op.evaluate({&x, &y}); ASSERT_EQ(Status::OK(), result.status()); ASSERT_EQ(result.at(0)->t(0), true); - + } TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { @@ -229,7 +229,7 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { ASSERT_EQ(Status::OK(), result.status()); // result->at(0)->printIndexedBuffer("RES SEED"); - + sd::ops::get_seed getOp; auto getRes = getOp.evaluate({}); ASSERT_EQ(Status::OK(), getRes.status()); @@ -252,7 +252,7 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) { // z->printIndexedBuffer("Scatter Div"); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -288,7 +288,7 @@ TEST_F(DeclarableOpsTests5, scatterSub_test1) { // z->printIndexedBuffer("Scatter Sub"); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -303,7 +303,7 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -319,7 +319,7 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -335,7 +335,7 @@ TEST_F(DeclarableOpsTests5, hardtanh_test1) { // z->printIndexedBuffer("Hardtanh 2x2"); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, hardtanh_test2) { @@ -351,7 +351,7 @@ TEST_F(DeclarableOpsTests5, hardtanh_test2) { // z->printIndexedBuffer("Hardtanh_bp 2x2"); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -367,7 +367,7 @@ TEST_F(DeclarableOpsTests5, histogram_test1) { // z->printIndexedBuffer("Histogram3"); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, histogram_test2) { @@ -381,7 +381,7 @@ TEST_F(DeclarableOpsTests5, histogram_test2) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -396,7 +396,7 @@ TEST_F(DeclarableOpsTests5, Identity_test1) { auto z = result.at(0); ASSERT_TRUE(matrix.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////// @@ -411,7 +411,7 @@ TEST_F(DeclarableOpsTests5, Identity_test2) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(eps)); - + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, Log1p_test1) { @@ -427,7 +427,7 @@ TEST_F(DeclarableOpsTests5, Log1p_test1) { auto z = result.at(0); ASSERT_TRUE(z->equalsTo(y)); - + } TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { @@ -445,7 +445,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -463,7 +463,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -483,7 +483,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) { 
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -512,7 +512,7 @@ TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) { @@ -530,7 +530,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) { @@ -547,7 +547,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } @@ -568,7 +568,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -590,7 +590,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -607,7 +607,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -623,7 +623,7 @@ TEST_F(DeclarableOpsTests5, eye_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -640,7 +640,7 @@ TEST_F(DeclarableOpsTests5, eye_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -656,7 +656,7 @@ TEST_F(DeclarableOpsTests5, eye_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -668,7 +668,7 @@ TEST_F(DeclarableOpsTests5, eye_test5) { auto z = result.at(0); ASSERT_EQ(ND4J_STATUS_OK, result.status()); - + } ////////////////////////////////////////////////////////////////////// @@ -688,7 +688,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -708,7 +708,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -727,7 +727,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -746,7 +746,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test5) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -783,7 +783,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test6) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -802,7 +802,7 @@ 
TEST_F(DeclarableOpsTests5, gatherNd_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -819,7 +819,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test8) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests5, gatherNd_test9) { @@ -840,7 +840,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test9) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -885,7 +885,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -904,7 +904,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -923,7 +923,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -942,7 +942,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -961,7 +961,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -980,7 +980,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1000,7 +1000,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1020,7 +1020,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1040,7 +1040,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1060,7 +1060,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1080,7 +1080,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1100,7 +1100,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1120,7 +1120,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1137,7 +1137,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { 
ASSERT_EQ(e, *z); - + } ////////////////////////////////////////////////////////////////////// @@ -1176,7 +1176,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_0) { for (int cases = 0; cases < 100; ++cases) { op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting } - + } ////////////////////////////////////////////////////////////////////// @@ -1215,7 +1215,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_1) { for (int cases = 0; cases < 100; ++cases) { op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting } - + } /////////////////////////////////////////////////////////// @@ -1264,7 +1264,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_2) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } TEST_F(DeclarableOpsTests5, Test_TopK_3) { @@ -1314,7 +1314,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_3) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) { @@ -1353,7 +1353,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } ////////////////////////////////////////////////////////////////////// @@ -1377,7 +1377,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_4) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } ////////////////////////////////////////////////////////////////////// @@ -1401,7 +1401,7 @@ TEST_F(DeclarableOpsTests5, Test_TopK_5) { ASSERT_TRUE(expI.isSameShape(i)); ASSERT_TRUE(expI.equalsTo(i)); - + } /////////////////////////////////////////////////////////// @@ -1440,7 +1440,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_1) { ASSERT_NEAR(expMean, v->e(0), inf); ASSERT_NEAR(expDeviation, d->e(0), inf); - + } TEST_F(DeclarableOpsTests5, Test_Moments_2) { @@ -1470,7 +1470,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_2) { ASSERT_TRUE(v->equalsTo(&expV)); ASSERT_TRUE(d->equalsTo(&expD)); - + } TEST_F(DeclarableOpsTests5, Test_Moments_3) { @@ -1504,7 +1504,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_3) { ASSERT_TRUE(v->equalsTo(&expV)); ASSERT_TRUE(d->equalsTo(&expD)); - + } TEST_F(DeclarableOpsTests5, Test_Moments_4) { @@ -1537,7 +1537,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) { ASSERT_TRUE(v->equalsTo(&expV)); ASSERT_TRUE(d->equalsTo(&expD)); - + } ////////////////////////////////////////////////////////////////////// @@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests5, trace_test1) { // output->printIndexedBuffer("OUT TRACE"); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1576,7 +1576,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1612,7 +1612,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1630,7 +1630,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1653,7 +1653,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { 
ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// @@ -1670,7 +1670,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) { ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(input.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -1693,7 +1693,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test04) { @@ -1714,7 +1714,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test04) { //ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// @@ -1736,7 +1736,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { //ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// @@ -1759,7 +1759,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) { ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// @@ -1782,7 +1782,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { ASSERT_TRUE(!input.equalsTo(output)); ASSERT_TRUE(!haveZeros); - + } ////////////////////////////////////////////////////////////////////// @@ -1800,7 +1800,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test7) { ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(input.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////////////// @@ -1835,7 +1835,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { ASSERT_TRUE(exp.equalsTo(output)); - + } TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { @@ -1871,7 +1871,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { ASSERT_TRUE(exp.equalsTo(output)); - + } TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { @@ -1912,7 +1912,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { ASSERT_TRUE(exp.equalsTo(output)); - + } /* @Test public void testDynamicPartition(){ @@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_01) { ASSERT_TRUE(exp[e].equalsTo(output)); } - + } TEST_F(DeclarableOpsTests5, DynamicPartition_1) { @@ -1995,7 +1995,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_1) { ASSERT_TRUE(exp[e].equalsTo(output)); } - + } //////////////////////////////////////////////////////////////////////////////// @@ -2024,7 +2024,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_2) { ASSERT_TRUE(exp[e].equalsTo(output)); } - + } @@ -2062,7 +2062,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) { } } - + } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { @@ -2078,7 +2078,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - + } TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { @@ -2094,7 +2094,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); ASSERT_EQ(Status::OK(), result.status()); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2119,7 +2119,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } 
//////////////////////////////////////////////////////////////////////////////// @@ -2148,7 +2148,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2177,7 +2177,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - + } ////////////////////////////////////////////////////////////////////// @@ -2206,7 +2206,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - + } ////////////////////////////////////////////////////////////////////// @@ -2235,7 +2235,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - + } ////////////////////////////////////////////////////////////////////// @@ -2270,7 +2270,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - + } ////////////////////////////////////////////////////////////////////// @@ -2305,7 +2305,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); - + } ////////////////////////////////////////////////////////////////////// @@ -2324,7 +2324,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2343,7 +2343,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2362,7 +2362,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2382,7 +2382,7 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////////////////// @@ -2400,7 +2400,7 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_1) { ASSERT_TRUE(res.at(0)->isScalar()); ASSERT_EQ(res.at(0)->e(0), 0.25); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2415,7 +2415,7 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_2) { ASSERT_TRUE(res.at(0)->isScalar()); ASSERT_EQ(res.at(0)->e(0), 0.375); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2430,8 +2430,9 @@ TEST_F(DeclarableOpsTests5, ZeroFraction_3) { ASSERT_TRUE(res.at(0)->isScalar()); ASSERT_EQ(res.at(0)->e(0), 0.375); - + } + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_1) { @@ -2451,6 +2452,7 @@ TEST_F(DeclarableOpsTests5, XWPlusB_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); } + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, XWPlusB_2) { @@ -2591,7 +2593,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_1) { 
ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2614,7 +2616,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_2) { ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2631,7 +2633,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test1) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2648,7 +2650,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test2) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2665,7 +2667,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test3) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } @@ -2683,7 +2685,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test5) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2700,7 +2702,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test6) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2717,7 +2719,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test7) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2734,7 +2736,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test8) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2751,7 +2753,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test9) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2768,7 +2770,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test10) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2785,7 +2787,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z)); - + } ////////////////////////////////////////////////////////////////////// @@ -2804,7 +2806,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test12) { ASSERT_TRUE(expOutput.isSameShape(z)); ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); - + } } @@ -2823,7 +2825,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2841,7 +2843,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } ////////////////////////////////////////////////////////////////////// @@ -2871,7 +2873,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_1) { ASSERT_EQ(output->e(0), exp); - + } TEST_F(DeclarableOpsTests5, L2_Loss_2) { @@ -2886,7 +2888,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_2) { ASSERT_EQ(e, *z); - + } TEST_F(DeclarableOpsTests5, L2_Loss_3) { @@ -2918,7 +2920,7 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + 
} //////////////////////////////////////////////////////////////////////////////// @@ -2939,7 +2941,7 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2991,7 +2993,7 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_1) { ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3042,7 +3044,7 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3093,6 +3095,6 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 450b32bcc..ed9dbee68 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -57,7 +57,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { @@ -78,7 +78,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { ASSERT_EQ(exp, *z); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { @@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { ASSERT_TRUE(z->isEmpty()); //ASSERT_EQ(exp, *z); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { @@ -122,7 +122,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { ASSERT_TRUE(z->equalsTo(exp)); //ASSERT_EQ(exp, *z); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { @@ -185,7 +185,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) { @@ -205,7 +205,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) { auto z = result.at(0); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { @@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { auto z = result.at(0); //ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { @@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) { auto z = result.at(0); //ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { @@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { auto z = result.at(0); //ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { @@ -292,7 +292,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { auto z = result.at(0); //ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { @@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_Order_1) { @@ -326,7 +326,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) { ASSERT_TRUE(exp.equalsTo(z)); ASSERT_NE(x.ordering(), z->ordering()); - + } 
TEST_F(DeclarableOpsTests6, cumSum_1) { @@ -342,7 +342,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_2) { @@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_3) { @@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_4) { @@ -391,7 +391,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_5) { @@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_6) { @@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_7) { @@ -436,7 +436,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) { ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, cumSum_8) { @@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -477,7 +477,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) { ASSERT_EQ(Status::OK(), result.status()); auto z = result.at(0); ASSERT_TRUE(expFF.equalsTo(z)); - + //************************************// exclusive = 1; reverse = 0; @@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expTF.equalsTo(z)); - + //************************************// exclusive = 0; reverse = 1; @@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expFT.equalsTo(z)); - + //************************************// exclusive = 1; reverse = 1; @@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) { ASSERT_EQ(Status::OK(), result.status()); z = result.at(0); ASSERT_TRUE(expTT.equalsTo(z)); - + } @@ -517,7 +517,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) { auto result = op.evaluate({&x, &y}, {}, {1, 1}); ASSERT_EQ(Status::OK(), result.status()); - + } //////////////////////////////////////////////////////////////////////////////// @@ -536,7 +536,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -593,7 +593,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -612,7 +612,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -631,7 +631,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) { ASSERT_TRUE(z->ews() == 1); ASSERT_TRUE(x.ews() == 1); - + } //////////////////////////////////////////////////////////////////////////////// @@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) { ASSERT_TRUE(exp.equalsTo(z)); - + } 
//////////////////////////////////////////////////////////////////////////////// @@ -697,7 +697,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -731,7 +731,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) { ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -779,30 +779,40 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, res.status()); -// res.at(0)->printIndexedBuffer("MergeMaxIndex Result is "); -// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex"); -// x.printIndexedBuffer("Input is"); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { - auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); - auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); - auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); - auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2}); sd::ops::mergemaxindex op; auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); ASSERT_EQ(ND4J_STATUS_OK, ress.status()); -// res.at(0)->printIndexedBuffer("MergeMaxIndex2 Result is "); -// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex2"); -// x.printIndexedBuffer("Input is"); ASSERT_TRUE(ress.at(0)->equalsTo(exp)); - + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) { + + auto x1 = NDArrayFactory::create('c', {3}, {1.f, 0.f, 0.f}); + auto x2 = NDArrayFactory::create('c', {3}, {0.f, 1.f, 0.f}); + auto x3 = NDArrayFactory::create('c', {3}, {0.f, 0.f, 1.f}); + NDArray z('c', {3}, sd::DataType::INT32); + NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32); + + sd::ops::mergemaxindex op; + auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {}); + + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(z.equalsTo(expZ)); } //////////////////////////////////////////////////////////////////////////////// @@ -818,7 +828,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) { //res.at(0)->printIndexedBuffer("Result is "); //x.printIndexedBuffer("Input is"); - + } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestMod_1) { @@ -834,7 +844,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) { // res.at(0)->printIndexedBuffer("MOD Result is "); // x.printIndexedBuffer("Input is"); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -853,7 +863,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) { // x.printIndexedBuffer("Input is"); 
ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } /////////////////////////////////////////////////////////////////////////////// @@ -870,7 +880,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) { ASSERT_EQ(ND4J_STATUS_OK, res.status()); ASSERT_TRUE(res.at(0)->equalsTo(exp)); - + } TEST_F(DeclarableOpsTests6, TestDropout_2) { // auto x0 = NDArrayFactory::create('c', {10, 10}); @@ -883,7 +893,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) { ASSERT_EQ(ND4J_STATUS_OK, res.status()); - + } TEST_F(DeclarableOpsTests6, TestDropout_3) { @@ -898,7 +908,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) { ASSERT_EQ(ND4J_STATUS_OK, res.status()); - + } //////////////////////////////////////////////////////////////////////////////// @@ -922,7 +932,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) { ASSERT_TRUE(expI.equalsTo(res.at(1))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -947,7 +957,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) { ASSERT_TRUE(sumExp.equalsTo(res.at(1))); ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -979,7 +989,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) { ASSERT_TRUE(sumExp.equalsTo(res.at(1))); ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1270,7 +1280,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) { ASSERT_TRUE(exp.equalsTo(z)); // ASSERT_TRUE(expNorm.equalsTo(norm)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1310,7 +1320,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) { ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(y)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1344,7 +1354,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) { ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(y)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1365,7 +1375,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1386,7 +1396,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1407,7 +1417,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1428,7 +1438,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1452,7 +1462,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1477,7 +1487,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1496,7 +1506,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } 
//////////////////////////////////////////////////////////////////////////////// @@ -1514,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1533,7 +1543,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1552,7 +1562,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1596,7 +1606,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1615,7 +1625,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1634,7 +1644,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1653,7 +1663,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1700,7 +1710,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } */ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { @@ -1733,7 +1743,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1767,7 +1777,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1801,7 +1811,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1835,7 +1845,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1864,7 +1874,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) { ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); - + } TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { @@ -1917,7 +1927,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -1960,7 +1970,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2003,7 +2013,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2045,7 +2055,7 
@@ TEST_F(DeclarableOpsTests6, static_rnn_test4) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2087,7 +2097,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2141,7 +2151,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2194,7 +2204,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } @@ -2247,7 +2257,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2290,7 +2300,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } @@ -2335,7 +2345,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2377,7 +2387,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2418,7 +2428,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2459,7 +2469,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) { ASSERT_TRUE(expHFinal.isSameShape(hFinal)); ASSERT_TRUE(expHFinal.equalsTo(hFinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2521,7 +2531,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2581,7 +2591,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2637,7 +2647,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } /////////////////////////////////////////////////////////////////// @@ -2696,7 +2706,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) { @@ -2749,7 +2759,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) { ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); - + } @@ -2763,7 +2773,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) { ASSERT_EQ(e, *result.at(0)); - + } TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { @@ -2776,7 +2786,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { ASSERT_EQ(e, 
*result.at(0)); - + } TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { @@ -2789,7 +2799,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { ASSERT_EQ(e, *result.at(0)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp index 589adebcb..0ca5e210a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -65,7 +65,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -83,7 +83,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -101,7 +101,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -119,7 +119,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -137,7 +137,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -155,7 +155,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -173,7 +173,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -191,7 +191,7 @@ TEST_F(DeclarableOpsTests8, reduceVariance_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -209,7 +209,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -227,7 +227,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -245,7 +245,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -263,7 +263,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -281,7 +281,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -299,7 +299,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test6) { 
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -317,7 +317,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -335,7 +335,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -353,7 +353,7 @@ TEST_F(DeclarableOpsTests8, reduceStDev_test08) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -374,28 +374,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test1) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,1}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0,0}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } @@ -417,28 +417,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test2) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } @@ -460,28 +460,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } @@ -507,28 +507,28 @@ TEST_F(DeclarableOpsTests8, reduceVarianceBP_test3) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); 
ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1, 1}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -550,28 +550,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { // output->printIndexedBuffer(); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,1}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0,0}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -592,28 +592,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test2) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -635,28 +635,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test02) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -677,28 +677,28 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test3) { auto output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp12.isSameShape(output)); ASSERT_TRUE(exp12.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {0,1}, {1}); 
ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + result = op.evaluate({&x, &gradO1}, {1,1}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp34.isSameShape(output)); ASSERT_TRUE(exp34.equalsTo(output)); - + } @@ -717,7 +717,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_1) { auto z = result.at(0); //z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -734,7 +734,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_2) { auto z = result.at(0); // z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -752,7 +752,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_03) { auto z = result.at(0); // z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -769,7 +769,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_1) { auto z = result.at(0); //z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -786,7 +786,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_2) { auto z = result.at(0); // z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -805,7 +805,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_01) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -825,7 +825,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_02) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -845,7 +845,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -865,7 +865,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -885,7 +885,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -905,7 +905,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -925,7 +925,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -944,7 +944,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_01) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -964,7 +964,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_02) { ASSERT_TRUE(exp.isSameShape(output)); 
ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -984,7 +984,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1004,7 +1004,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1025,7 +1025,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1045,7 +1045,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1065,7 +1065,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1085,7 +1085,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1104,7 +1104,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1124,7 +1124,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1144,7 +1144,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1164,7 +1164,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1185,7 +1185,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1205,7 +1205,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1225,7 +1225,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1245,7 +1245,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Min_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1265,7 +1265,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } 
//////////////////////////////////////////////////////////////////////////////// @@ -1285,7 +1285,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1305,7 +1305,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1325,7 +1325,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1346,7 +1346,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1366,7 +1366,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1386,7 +1386,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1406,7 +1406,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Max_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) { @@ -1424,7 +1424,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1444,7 +1444,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1464,7 +1464,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1484,7 +1484,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1505,7 +1505,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1525,7 +1525,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1545,7 +1545,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1565,7 +1565,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_1) { @@ -1583,7 +1583,7 @@ TEST_F(DeclarableOpsTests8, 
Test_Reduce_Norm2_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1603,7 +1603,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1623,7 +1623,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1643,7 +1643,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1664,7 +1664,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1684,7 +1684,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1704,7 +1704,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1724,7 +1724,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1743,7 +1743,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1762,7 +1762,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1781,7 +1781,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1800,7 +1800,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1820,7 +1820,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1840,7 +1840,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1860,7 +1860,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1880,7 +1880,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_7) { 
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1899,7 +1899,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1918,7 +1918,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1937,7 +1937,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1956,7 +1956,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1976,7 +1976,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_04) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -1996,7 +1996,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2016,7 +2016,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2036,7 +2036,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2055,7 +2055,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_1) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2076,7 +2076,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_2) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2097,7 +2097,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_3) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2118,7 +2118,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_4) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2141,7 +2141,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_04) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2166,7 +2166,7 @@ TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_BP_1) { // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); - + } 
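// For reference, a minimal sketch of the contract the Test_Reduce_Sum_BP_* cases above
// exercise; the shapes, values and the exact op invocation below are illustrative
// assumptions and are not taken from this patch. Backprop of a full sum simply
// broadcasts the incoming gradient, since d(sum)/dx_i == 1 for every element.
auto x     = NDArrayFactory::create<float>('c', {3, 4});
auto gradO = NDArrayFactory::create<float>(0.5f);          // gradient w.r.t. the scalar sum
auto exp   = NDArrayFactory::create<float>('c', {3, 4});
x.linspace(1.f);
exp.assign(0.5f);                                          // gradO broadcast over the whole input

sd::ops::reduce_sum_bp op;                                 // op name assumed from the test name
auto result = op.evaluate({&x, &gradO}, {}, {});           // empty iArgs -> reduce over all axes
ASSERT_TRUE(exp.isSameShape(result.at(0)));
ASSERT_TRUE(exp.equalsTo(result.at(0)));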
//////////////////////////////////////////////////////////////////////////////// @@ -2186,7 +2186,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2206,7 +2206,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test2) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2226,7 +2226,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test3) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2246,7 +2246,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2266,7 +2266,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test5) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2285,7 +2285,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test6) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2304,7 +2304,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test7) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2324,7 +2324,7 @@ TEST_F(DeclarableOpsTests8, reduceMean_test8) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2347,14 +2347,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -2375,14 +2375,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test2) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2403,14 +2403,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test02) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2430,14 +2430,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test3) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } 
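// Likewise, a minimal sketch of the axis-wise case the reduceMeanBP_* tests above cover;
// again the shapes, values and argument layout are illustrative assumptions rather than
// content of this patch. Backprop of a mean over an axis broadcasts gradO divided by the
// length of that axis.
auto x     = NDArrayFactory::create<float>('c', {3, 4});
auto gradO = NDArrayFactory::create<float>('c', {4});      // keepDims == false, so the axis-0 mean has shape {4}
auto exp   = NDArrayFactory::create<float>('c', {3, 4});
x.linspace(1.f);
gradO.assign(1.f);
exp.assign(1.f / 3.f);                                     // each element contributes 1/3 to its column mean

sd::ops::reduce_mean_bp op;                                // op name assumed from the test name
auto result = op.evaluate({&x, &gradO}, {0}, {0});         // tArg 0 -> keepDims false, iArg 0 -> reduction axis
ASSERT_TRUE(exp.isSameShape(result.at(0)));
ASSERT_TRUE(exp.equalsTo(result.at(0)));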
//////////////////////////////////////////////////////////////////////////////// @@ -2455,7 +2455,7 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -2478,7 +2478,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2500,7 +2500,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2522,7 +2522,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2544,7 +2544,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test4) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2566,7 +2566,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test5) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2588,7 +2588,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test6) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2610,7 +2610,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test7) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2632,7 +2632,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2652,7 +2652,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test9) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - + } /////////////////////////////////////////////////////////////////// @@ -2674,161 +2674,7 @@ TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test10) { ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); - -} -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test4) { - - auto x = NDArrayFactory::create('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {1.f}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// 
-TEST_F(DeclarableOpsTests8, clipbynorm_test5) { - - // auto x = NDArrayFactory::create('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5}); - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676}); - // auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {0}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test6) { - - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test7) { - - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.f}, {0,1}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test8) { - - auto x = NDArrayFactory::create('c', {3, 5}); - auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {15.}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test9) { - - auto x = NDArrayFactory::create('c', {2}, {3., 4.}); - auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {4.}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test10) { - - auto x = NDArrayFactory::create(6.); - auto exp = NDArrayFactory::create(5.); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {5.}, {}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests8, clipbynorm_test11) { - - auto x = NDArrayFactory::create('c', {2, 3, 4}); - auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061, - 13., 14., 15., 16., 
15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123}); - - x.linspace(1); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {35.}, {0, 2}); - auto output = result.at(0); - - ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); - - -} - - -TEST_F(DeclarableOpsTests8, clipbynorm_test_tf_119_1) { - auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9}); - auto e = NDArrayFactory::create('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155}); - - sd::ops::clipbynorm op; - auto result = op.evaluate({&x}, {0.54}, {}); - - ASSERT_EQ(e, *result.at(0)); - - } //////////////////////////////////////////////////////////////////////////////// @@ -2846,14 +2692,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test4) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -2872,14 +2718,14 @@ TEST_F(DeclarableOpsTests8, reduceMeanBP_test5) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {1}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } @@ -2898,14 +2744,14 @@ TEST_F(DeclarableOpsTests8, reduceStDevBP_test5) { auto output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + result = op.evaluate({&x, &gradO2}, {1}, {0}); ASSERT_EQ(Status::OK(), result.status()); output = result.at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2940,7 +2786,7 @@ TEST_F(DeclarableOpsTests8, zeros_as_test2) { ASSERT_TRUE(y->isSameShape(exp)); ASSERT_TRUE(y->equalsTo(exp)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2975,7 +2821,7 @@ TEST_F(DeclarableOpsTests8, ones_as_test2) { ASSERT_TRUE(y->isSameShape(exp)); ASSERT_TRUE(y->equalsTo(exp)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -2994,7 +2840,7 @@ TEST_F(DeclarableOpsTests8, ones_as_test3) { ASSERT_TRUE(y->isSameShape(exp)); ASSERT_TRUE(y->equalsTo(exp)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3042,7 +2888,7 @@ TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { // ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); // ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3073,7 +2919,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_1) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3102,7 +2948,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_2) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3131,7 +2977,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_3) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); 
ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3160,7 +3006,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_4) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3187,7 +3033,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_6) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3215,7 +3061,7 @@ TEST_F(DeclarableOpsTests8, Test_Moments_7) { ASSERT_TRUE(expVariance.isSameShape(outputVariance)); ASSERT_TRUE(expVariance.equalsTo(outputVariance)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3239,7 +3085,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) { //exp.printBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3261,7 +3107,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3280,7 +3126,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3309,7 +3155,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3386,7 +3232,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3463,7 +3309,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) { // exp.printIndexedBuffer("LRN exp"); ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3483,7 +3329,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4) { // exp.printIndexedBuffer("LRN exp"); // ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3535,7 +3381,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) { //exp.printBuffer("LRN BP exp"); //ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3591,7 +3437,7 @@ auto exp = NDArrayFactory::create('c', {3,3,5,5}, { // exp.printBuffer("LRN BP exp"); //ASSERT_TRUE(exp.equalsTo(out)); - + } //////////////////////////////////////////////////////////////////////////////// @@ -3671,7 +3517,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { // exp.printIndexedBuffer("LRN exp"); // ASSERT_TRUE(exp.equalsTo(out)); - + } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 556ce3bb6..949b43d25 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -236,10 +236,10 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { 
//////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test1) { - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('c', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('c', {2,2,4}); + auto x2 = NDArrayFactory::create('c', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); x0.linspace(1); @@ -261,10 +261,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test2) { - auto x0 = NDArrayFactory::create('c', {1,3,1}); - auto x1 = NDArrayFactory::create('c', {1,2,1}); - auto x2 = NDArrayFactory::create('c', {1,1,1}); - auto exp = NDArrayFactory::create('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); + auto x0 = NDArrayFactory::create('c', {1,3,1}); + auto x1 = NDArrayFactory::create('c', {1,2,1}); + auto x2 = NDArrayFactory::create('c', {1,1,1}); + auto exp = NDArrayFactory::create('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); x0.linspace(1); x1.linspace(1); @@ -285,10 +285,10 @@ TEST_F(DeclarableOpsTests9, concat_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test3) { - auto x0 = NDArrayFactory::create('c', {3}); - auto x1 = NDArrayFactory::create('c', {2}); - auto x2 = NDArrayFactory::create('c', {1}); - auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); + auto x0 = NDArrayFactory::create('c', {3}); + auto x1 = NDArrayFactory::create('c', {2}); + auto x2 = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); x0.linspace(1); x1.linspace(1); @@ -300,21 +300,17 @@ TEST_F(DeclarableOpsTests9, concat_test3) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); - output->printBuffer(); - ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test4) { - auto x0 = NDArrayFactory::create('c', {1,1,1}, {1.f}); - auto x1 = NDArrayFactory::create('c', {1,1,1}, {2.f}); - auto x2 = NDArrayFactory::create('c', {1,1,1}, {3.f}); - auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 2.f, 3.f}); + auto x0 = NDArrayFactory::create('c', {1,1,1}, {1.f}); + auto x1 = NDArrayFactory::create('c', {1,1,1}, {2.f}); + auto x2 = NDArrayFactory::create('c', {1,1,1}, {3.f}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 2.f, 3.f}); sd::ops::concat op; @@ -331,10 +327,10 @@ TEST_F(DeclarableOpsTests9, concat_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test5) { - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {1}, {2.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto x0 = 
NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {1}, {2.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); sd::ops::concat op; @@ -351,10 +347,10 @@ TEST_F(DeclarableOpsTests9, concat_test5) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test6) { - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); sd::ops::concat op; @@ -371,10 +367,10 @@ TEST_F(DeclarableOpsTests9, concat_test6) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test7) { - auto x0 = NDArrayFactory::create(1.f); - auto x1 = NDArrayFactory::create(2.f); - auto x2 = NDArrayFactory::create(3.f); - auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create(2.f); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); sd::ops::concat op; @@ -391,8 +387,8 @@ TEST_F(DeclarableOpsTests9, concat_test7) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test8) { - auto x0 = NDArrayFactory::create(1.f); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); + auto x0 = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); sd::ops::concat op; @@ -409,8 +405,8 @@ TEST_F(DeclarableOpsTests9, concat_test8) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test9) { - auto x0 = NDArrayFactory::create('c', {1}, {1.f}); - auto exp = NDArrayFactory::create('c', {1}, {1.f}); + auto x0 = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); sd::ops::concat op; @@ -427,10 +423,10 @@ TEST_F(DeclarableOpsTests9, concat_test9) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test10) { - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('c', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('c', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); x0.linspace(1); @@ -452,10 +448,10 @@ TEST_F(DeclarableOpsTests9, concat_test10) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test11) { - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = 
NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); x0.linspace(1); @@ -477,10 +473,10 @@ TEST_F(DeclarableOpsTests9, concat_test11) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test12) { - auto x0 = NDArrayFactory::create('c', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); x0.linspace(1); @@ -502,10 +498,10 @@ TEST_F(DeclarableOpsTests9, concat_test12) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test13) { - auto x0 = NDArrayFactory::create('f', {2,3,4}); - auto x1 = NDArrayFactory::create('f', {2,2,4}); - auto x2 = NDArrayFactory::create('f', {2,1,4}); - auto exp = NDArrayFactory::create('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, + auto x0 = NDArrayFactory::create('f', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f, 3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); x0.linspace(1); @@ -527,8 +523,8 @@ TEST_F(DeclarableOpsTests9, concat_test13) { TEST_F(DeclarableOpsTests9, concat_test14) { - NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE); - NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE); + NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32); x0 = 1.; x1 = 2.; @@ -544,7 +540,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) { for (int e = 0; e < numOfTads; ++e) { NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); + auto mean = tad.meanNumber().e(0); ASSERT_NEAR((e+1)*1., mean, 1e-5); } @@ -552,9 +548,9 @@ TEST_F(DeclarableOpsTests9, concat_test14) { } TEST_F(DeclarableOpsTests9, concat_test15) { - auto x = NDArrayFactory::create('c', {2}, {1, 0}); - auto y = NDArrayFactory::create (3.0f); - auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); + auto x 
= NDArrayFactory::create('c', {2}, {1, 0}); + auto y = NDArrayFactory::create (3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); sd::ops::concat op; auto result = op.evaluate({&x, &y}, {}, {0}); @@ -571,9 +567,9 @@ TEST_F(DeclarableOpsTests9, concat_test15) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test16) { - auto x = NDArrayFactory::create('c', {0,2,3}); - auto y = NDArrayFactory::create('c', {0,2,3}); - auto exp = NDArrayFactory::create('c', {0,2,3}); + auto x = NDArrayFactory::create('c', {0,2,3}); + auto y = NDArrayFactory::create('c', {0,2,3}); + auto exp = NDArrayFactory::create('c', {0,2,3}); sd::ops::concat op; auto result = op.evaluate({&x, &y}, {}, {0}); @@ -587,8 +583,8 @@ TEST_F(DeclarableOpsTests9, concat_test16) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test17) { - NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE); - NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE); + NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32); x0 = 1.; x1 = 2.; @@ -606,7 +602,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) { for (int e = 0; e < numOfTads; ++e) { NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); + auto mean = tad.meanNumber().e(0); ASSERT_NEAR((e+1)*1., mean, 1e-5); } } @@ -664,10 +660,10 @@ TEST_F(DeclarableOpsTests9, concat_test19) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test20) { - auto x0 = NDArrayFactory::create('c', {1, 100, 150}); - auto x1 = NDArrayFactory::create('c', {1, 100, 150}); - auto x2 = NDArrayFactory::create('c', {1, 100, 150}); - auto x3 = NDArrayFactory::create('c', {1, 100, 150}); + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); x0.assign(1.0); x1.assign(2.0); @@ -685,8 +681,8 @@ TEST_F(DeclarableOpsTests9, concat_test20) { for (int e = 0; e < numOfTads; e++) { NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((double) e+1, mean, 1e-5); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((float) e+1, mean, 1e-5); } @@ -710,10 +706,10 @@ TEST_F(DeclarableOpsTests9, concat_test21) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test22) { - NDArray x0('c', {1,6}, {1,2,3,4,5,6}); - NDArray x1('c', {1,6}, {7,8,9,10,11,12}); - NDArray output('f', {2,6}, sd::DataType::DOUBLE); - NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); + NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); + NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32); + NDArray output('f', {2,6}, sd::DataType::FLOAT32); + NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); sd::ops::concat op; @@ -726,10 +722,10 @@ TEST_F(DeclarableOpsTests9, concat_test22) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test23) { - NDArray x0('c', {1,4}, {1,2,3,4}); - NDArray x1('c', {1,4}, {5,6,7,8}); - NDArray output('c', {2,4}, sd::DataType::DOUBLE); - NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); + NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32); + NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32); + NDArray output('c', 
{2,4}, sd::DataType::FLOAT32); + NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32); sd::ops::concat op; @@ -741,10 +737,10 @@ TEST_F(DeclarableOpsTests9, concat_test23) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test24) { - auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); - auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto z = NDArrayFactory::create('c', {2, 2}); + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); sd::ops::concat op; auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); @@ -756,10 +752,10 @@ TEST_F(DeclarableOpsTests9, concat_test24) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test25) { - auto x0 = NDArrayFactory::create('c', {1,4}, {1,2,3,4}); - auto x1 = NDArrayFactory::create('c', {1,4}, {5,6,7,8}); - auto axis = NDArrayFactory::create('c', {1}, {0.}); - auto exp = NDArrayFactory::create('c', {2,4}, {1,2,3,4,5,6,7,8}); + auto x0 = NDArrayFactory::create('c', {1,4}, {1,2,3,4}); + auto x1 = NDArrayFactory::create('c', {1,4}, {5,6,7,8}); + auto axis = NDArrayFactory::create('c', {1}, {0.}); + auto exp = NDArrayFactory::create('c', {2,4}, {1,2,3,4,5,6,7,8}); sd::ops::concat op; @@ -793,7 +789,7 @@ TEST_F(DeclarableOpsTests9, concat_test26) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); - output->printLinearBuffer(); + // output->printLinearBuffer(); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -802,10 +798,10 @@ TEST_F(DeclarableOpsTests9, concat_test26) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, concat_test27) { - auto x1 = NDArrayFactory::create('c', {0,1}); - auto x2 = NDArrayFactory::create('c', {0,1}); - auto x3 = NDArrayFactory::create('c', {0,1}); - auto x4 = NDArrayFactory::create('c', {0,1}); + auto x1 = NDArrayFactory::create('c', {0,1}); + auto x2 = NDArrayFactory::create('c', {0,1}); + auto x3 = NDArrayFactory::create('c', {0,1}); + auto x4 = NDArrayFactory::create('c', {0,1}); std::vector expShape = {0, 4}; @@ -1245,109 +1241,6 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { } -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, clipbynorm_test12) { - - const int bS = 5; - const int nOut = 4; - const int axis = 0; - const double clip = 2.; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1] - auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); - auto expect = NDArrayFactory::create('c', {bS, nOut}); - - auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] - - auto y = ( (x / norm2) * clip) * colVect ; - auto temp = (x / norm2) * clip; - - for (int j = 0; j < nOut; ++j) { - auto yCol = y({0,0, j,j+1}); - const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); - if (norm2Col <= clip) - expect({0,0, j,j+1}).assign(yCol); - else - expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) ); - } - - sd::ops::clipbynorm op; - auto 
result = op.evaluate({&y}, {clip}, {axis}); - auto outFF = result.at(0); - - ASSERT_TRUE(expect.isSameShape(outFF)); - ASSERT_TRUE(expect.equalsTo(outFF)); - - -} - - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) { - - const int bS = 2; - const int nOut = 3; - const double clip = 0.7; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); - - const OpArgsHolder argsHolderFF({&x}, {clip}, {}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); - - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) { - - const int bS = 2; - const int nOut = 3; - const int axis = 0; - const double clip = 0.7; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); - - const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); - - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); -} - - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) { - - const int bS = 2; - const int nOut = 3; - const int axis = 1; - const double clip = 1.; - - auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] - auto gradO = NDArrayFactory::create('c', {bS, nOut}); - - const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); - const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); - - sd::ops::clipbynorm opFF; - sd::ops::clipbynorm_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); - - ASSERT_TRUE(isGradCorrect); -} - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_1) {