diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 5108ba4a7..7de54a858 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -112,7 +112,8 @@ (4, IsInfOrNan), \ (5, MatchConditionBool), \ (6, IsPositive) , \ - (7, Not) + (7, Not), \ + (8, IsNegative) #define TRANSFORM_STRICT_OPS \ @@ -279,7 +280,8 @@ (3, IsInfOrNan), \ (4, IsNan), \ (5, IsInf), \ - (6, IsPositive) + (6, IsPositive), \ + (7, IsNegative) #define REDUCE_SAME_OPS \ (0, Sum), \ diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index d790dd9c2..1aa0c5249 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -27,7 +27,8 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { +//////////////////////////////////////////////////////////////////// +CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) { return Status::OK(); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); - - const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e(0) : T_ARG(0); - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - // compute mean before + + NDArray* factor = nullptr; + + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + // fill up axes vector first std::vector axes(input->rankOf() - 1); for (auto i = 0; i < axes.size(); ++i) axes[i] = i; + // mean as reduction for last dimension set auto mean = input->reduceAlongDims(reduce::Mean, axes); - NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext()); - factorT.p(0, factor); // this is contrast calculation - output->assign((*input - mean) * factorT + mean); + output->assign((*input - mean) * (*factor) + mean); + + if(block.width() == 1) + delete factor; return Status::OK(); } @@ -64,45 +73,54 @@ DECLARE_TYPES(adjust_contrast) { ->setSameMode(true); } +//////////////////////////////////////////////////////////////////// +CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { - CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) { - - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); - - const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e(0) : T_ARG(0); - - // just skip op if input is empty - if (input->isEmpty()) - return Status::OK(); - - REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); - - // compute mean before - std::vector axes(input->rankOf() - 1); - for (auto i = 0; i < axes.size(); ++i) - axes[i] = i; - - // mean as reduction for last dimension set - auto mean = input->reduceAlongDims(reduce::Mean, axes); - - // result as (x - mean) * factor + mean - auto temp = input->ulike(); - input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); - temp.applyScalar(scalar::Multiply, factor); - temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + // just skip op if input is empty + if (input->isEmpty()) return Status::OK(); + + REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); + REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); + + NDArray* factor = nullptr; + + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); } - DECLARE_TYPES(adjust_contrast_v2) { - getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}) - ->setSameMode(true); - } + // compute mean before + std::vector axes(input->rankOf() - 1); + for (auto i = 0; i < axes.size(); ++i) + axes[i] = i; + + // mean as reduction for last dimension set + auto mean = input->reduceAlongDims(reduce::Mean, axes); + + // result as (x - mean) * factor + mean + auto temp = input->ulike(); + input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); + temp.applyScalarArr(scalar::Multiply, factor); + temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); + + if(block.width() == 1) + delete factor; + + return Status::OK(); +} + +DECLARE_TYPES(adjust_contrast_v2) { + getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}) + ->setSameMode(true); +} } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp index d1d81acf8..32e51bdb9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_hue.cpp @@ -24,13 +24,12 @@ #include #include -#include namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) { +CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) { const int rank = input->rankOf(); const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - const double delta = T_ARG(0); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !"); REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); - REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); - NDArray deltaScalarArr = NDArrayFactory::create(delta, block.launchContext()); + NDArray* delta = nullptr; - helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC); + if(block.width() > 1) + delta = INPUT_VARIABLE(1); + else { + delta = new NDArray(output->dataType(), block.launchContext()); + delta->p(0, T_ARG(0)); + } + + REQUIRE_TRUE(-1. <= delta->e(0) && delta->e(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta); + + helpers::adjustHue(block.launchContext(), input, delta, output, dimC); + + if(block.width() == 1) + delete delta; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp index 5030e5952..de947c9ae 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_saturation.cpp @@ -28,7 +28,7 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) { +CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -37,16 +37,26 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) { if (input->isEmpty()) return Status::OK(); - const int rank = input->rankOf(); - const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; - const double factor = T_ARG(0); + const int rank = input->rankOf(); + const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1; REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank); REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC)); + REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !"); - NDArray factorScalarArr = NDArrayFactory::create(factor, block.launchContext()); + NDArray* factor = nullptr; - helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC); + if(block.width() > 1) + factor = INPUT_VARIABLE(1); + else { + factor = new NDArray(output->dataType(), block.launchContext()); + factor->p(0, T_ARG(0)); + } + + helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC); + + if(block.width() == 1) + delete factor; return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/helpers/polyGamma.h b/libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp similarity index 57% rename from libnd4j/include/ops/declarable/helpers/polyGamma.h rename to libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp index 5681d68b7..8a2894be7 100644 --- a/libnd4j/include/ops/declarable/helpers/polyGamma.h +++ b/libnd4j/include/ops/declarable/generic/parity_ops/digamma.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,27 +16,35 @@ ******************************************************************************/ // -// Created by Yurii Shyrma on 13.12.2017. +// @author Yurii Shyrma (iuriish@yahoo.com) // -#ifndef LIBND4J_POLYGAMMA_H -#define LIBND4J_POLYGAMMA_H +#include +#if NOT_EXCLUDED(OP_digamma) -#include -#include "NDArray.h" +#include +#include namespace nd4j { -namespace ops { -namespace helpers { +namespace ops { +CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) { - // calculate the polygamma function - void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output); - + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); + + helpers::diGamma(block.launchContext(), *x, *z); + + return Status::OK(); +} + +DECLARE_TYPES(digamma) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS}) + ->setSameMode(true); +} } } -} - -#endif //LIBND4J_POLYGAMMA_H +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp index 0f850cd4b..1cfd86a26 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/polygamma.cpp @@ -15,14 +15,14 @@ ******************************************************************************/ // -// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include #if NOT_EXCLUDED(OP_polygamma) #include -#include +#include namespace nd4j { namespace ops { @@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) { Nd4jLong arrLen = n->lengthOf(); // FIXME: this shit should be single op call, not a loop! - auto nPositive = n->reduceNumber(nd4j::reduce::IsPositive, nullptr); - auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr); - bool nPositiveFlag = nPositive.e(0); - bool xPositiveFlag = xPositive.e(0); - REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !"); + auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr); + auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr); + bool nPositiveFlag = !nNegative.e(0); // require all n >= 0 + bool xPositiveFlag = xPositive.e(0); // require all x > 0 + REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !"); REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !"); helpers::polyGamma(block.launchContext(), *n, *x, *output); diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 0d354630c..e56ba9d6e 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -513,7 +513,6 @@ namespace nd4j { /** * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * Currently the case n = 0 is not supported. * * Input arrays: * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) @@ -528,6 +527,20 @@ namespace nd4j { DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0); #endif + /** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ + #if NOT_EXCLUDED(OP_digamma) + DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0); + #endif + /** * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. * Input arrays: @@ -575,44 +588,47 @@ namespace nd4j { * This operation adjusts image hue by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing delta * * T arguments: - * 0 - delta value + * 0 - optional argument, delta value * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_adjust_hue) - DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2); + DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0); #endif /** * This operation adjusts image saturation by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation factor * * T arguments: - * 0 - saturation factor + * 0 - optional argument, saturation factor * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels */ #if NOT_EXCLUDED(OP_adjust_saturation) - DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2); + DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0); #endif /** * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) * Input arrays: * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation contrast factor * * T arguments: - * 0 - contrast factor + * 0 - optional argument, contrast factor * */ #if NOT_EXCLUDED(OP_adjust_contrast) - DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0); - DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0); + DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0); #endif @@ -1832,7 +1848,7 @@ namespace nd4j { #endif /** - * compare_and_bitpack - compare with greater and pack result with uint8 + * compare_and_bitpack - compare with greater and pack result with uint8 * * input params: * 0 - NDArray (input) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp index ddd1ad892..88186b62a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp @@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) { return x; const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1 - x) * b - gammaPart) / a; + const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFraction(a, b, x); + return front * continuedFraction(a, b, x) / a; else // symmetry relation - return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x); + return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp new file mode 100644 index 000000000..8035f8216 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp @@ -0,0 +1,53 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +// calculate digamma function for array elements +template +static void diGamma_(const NDArray& x, NDArray& z) { + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i += increment) + z.p(i, diGammaScalar(x.e(i))); + }; + samediff::Threads::parallel_for(func, 0, x.lengthOf()); +} + +void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) { + + BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES); +} + +BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES); + + + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp index cb97ffe1e..fc572677e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/polyGamma.cpp @@ -18,7 +18,7 @@ // Created by Yurii Shyrma on 12.12.2017 // -#include +#include #include #include #include @@ -42,7 +42,7 @@ static FORCEINLINE T getFactorial(const int n) { for(int i = 2; i <= n; ++i) result *= i; - + return result; } @@ -50,17 +50,15 @@ static FORCEINLINE T getFactorial(const int n) { // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) template static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, const T x) { - - // if (n < 0) + + // if (n < 0) // throw("polyGamma function: n must be >= 0 !"); - // if (x <= (T)0.) + // if (x <= (T)0.) // throw("polyGamma function: x must be > 0 !"); - - // TODO case for n = 0 (digamma) int sign = (n + 1) % 2 ? -1 : 1; - // T factorial = (T)std::tgamma(n + 1); + // T factorial = (T)std::tgamma(n + 1); return sign * getFactorial(n) * zetaScalar((T)(n + 1), x); } @@ -71,17 +69,18 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, template static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { - NDArray& result = output; - - int xLen = x.lengthOf(); - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i += increment) - result.p(i, polyGammaScalar(context, n.e(i), x.e(i))); + for (auto i = start; i < stop; i += increment) { + const T order = n.e(i); + if(order != static_cast(order)) // if order has fractional part then do not perform calculations and return NAN + output.p(i, std::numeric_limits::quiet_NaN()); + else if (order == 0) // polygamma function of zero order is digamma function + output.p(i, diGammaScalar(x.e(i))); + else + output.p(i, polyGammaScalar(context, order, x.e(i))); + } }; samediff::Threads::parallel_for(func, 0, x.lengthOf()); - -// return result; } void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index 90619c76c..e7541a005 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) { return 1.f / 0.f; // no convergence, more iterations is required } -/////////////////////////////////////////////////////////////////// -// evaluates incomplete beta function for positive a and b, and x between 0 and 1. -template -__device__ T betaIncCoreCuda(T a, T b, T x) { - - const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); - const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1 - x) * b - gammaPart) / a; - - if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFractionCuda(a, b, x); - else // symmetry relation - return static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x); -} - /////////////////////////////////////////////////////////////////// template __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, @@ -115,12 +101,21 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, const Nd4jLong j = blockIdx.x; // one block per each element - Nd4jLong len = shape::length(xShapeInfo); + T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); - const T a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); - const T b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); - const T x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - T& z = *(reinterpret_cast(vz) + shape::getIndexOffset(j, zShapeInfo)); + __shared__ T a, b, x; + __shared__ bool symmCond; + + if (threadIdx.x == 0) { + + a = *(reinterpret_cast(va) + shape::getIndexOffset(j, aShapeInfo)); + b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); + x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); + + symmCond = x <= (a + static_cast(1)) / (a + b + static_cast(2)); + + } + __syncthreads(); // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5 if(a == b && x == static_cast(0.5)) { @@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, if(threadIdx.x % 2 == 0) { /***** even part *****/ const int m = threadIdx.x + 1; - sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); + if(symmCond) + sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); + else + sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast(1)) * (b + 2 * m)); } else { /***** odd part *****/ const int m = threadIdx.x; - sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); + if(symmCond) + sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); + else + sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast(1)) * (b + 2 * m)); } __syncthreads(); - if(threadIdx.x == 0) - z = betaIncCoreCuda(a, b, x); + if(threadIdx.x == 0) { + + const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); + const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); + + if (symmCond) + z = front * continuedFractionCuda(a, b, x) / a; + else // symmetry relation + z = static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x) / b; + } } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu new file mode 100644 index 000000000..3edd59ecc --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#include + +namespace nd4j { +namespace ops { +namespace helpers { + +/////////////////////////////////////////////////////////////////// +template +__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + + __shared__ Nd4jLong len; + __shared__ bool sameOffset; + + if (threadIdx.x == 0) { + len = shape::length(xShapeInfo); + sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo); + } + __syncthreads(); + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) { + + const auto xOffset = shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo); + + z[zOffset] = diGammaScalar(x[xOffset]); + } +} + +/////////////////////////////////////////////////////////////////// +template +static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { + + diGammaCuda<<>>(vx, xShapeInfo, vz, zShapeInfo); +} + +/////////////////////////////////////////////////////////////////// +void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) { + + int threadsPerBlock = MAX_NUM_THREADS / 2; + int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + + NDArray::prepareSpecialUse({&z}, {&x}); + BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); + NDArray::registerSpecialUse({&z}, {&x}); +} + +BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); + +} +} +} + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 01b9464fa..6a62246a3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -18,7 +18,7 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019 // -#include +#include #include #include @@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, auto z = reinterpret_cast(vz); __shared__ Nd4jLong len; + __shared__ bool sameOffsetNX, sameOffsetNZ; - if (threadIdx.x == 0) + if (threadIdx.x == 0) { len = shape::length(nShapeInfo); + sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo); + sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo); + } __syncthreads(); const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo, for (int i = tid; i < len; i += totalThreads) { const auto nOffset = shape::getIndexOffset(i, nShapeInfo); - const auto xOffset = shape::getIndexOffset(i, xShapeInfo); - const auto zOffset = shape::getIndexOffset(i, zShapeInfo); + const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo); + const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo); - const T nVal = n[nOffset]; + const T order = n[nOffset]; - int sign = (static_cast(nVal) + 1) % 2 ? -1 : 1; + int sign = (static_cast(order) + 1) % 2 ? -1 : 1; - T factorial = 1; - if(nVal != 0 && nVal != 1) - for(int i = 2; i <= nVal; ++i) - factorial *= i; + if(order != static_cast(order)) { + z[zOffset] = DataTypeUtils::nanOrZero(); + } + else if(order == 0) { + z[zOffset] = diGammaScalar(x[xOffset]); + } + else { + T factorial = 1; + for(int i = 2; i <= order; ++i) + factorial *= i; - z[zOffset] = sign * factorial * zetaScalar(nVal + 1, x[xOffset]); + z[zOffset] = sign * factorial * zetaScalar(order + 1, x[xOffset]); + } } } @@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x NDArray::prepareSpecialUse({&z}, {&n, &x}); - int threadsPerBlock = MAX_NUM_THREADS; + int threadsPerBlock = MAX_NUM_THREADS / 2; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h new file mode 100644 index 000000000..2ad540409 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/gammaMathFunc.h @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#ifndef LIBND4J_GAMMAMATHFUNC_H +#define LIBND4J_GAMMAMATHFUNC_H + +#include +#include "NDArray.h" + +namespace nd4j { +namespace ops { +namespace helpers { + + // calculate the digamma function for each element for array + void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z); + + // calculate the polygamma function + void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z); + + // calculate the digamma function for one element + // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x) + template + _CUDA_HD T diGammaScalar(T x) { + + const int xInt = static_cast(x); + + // negative and zero + if(x <= 0) { + if(x == xInt) // integer + return DataTypeUtils::infOrMax(); + else + return diGammaScalar(1 - x) - M_PI / nd4j::math::nd4j_tan(M_PI * x); // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x) + } + + // positive integer + if(x == xInt && xInt <= 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n + T result = -0.577215664901532; + for (uint i = 1; i <= xInt - 1; ++i) { + result += static_cast(1) / i; + } + return result; + } + + // positive half-integer + if(x - xInt == 0.5 && xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n + T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log(2); + for (uint i = 1; i <= xInt; ++i) { + result += static_cast(2) / (2*i - 1); + } + return result; + } + + // positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion + if(x < 5) + return diGammaScalar(1 + x) - static_cast(1) / x; // recurrence formula psi(x) = psi(x+1) - 1/x. + + // *** other positive **** // + + // truncated expansion formula (from wiki) + // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ... + + if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8)) // if x is too big take into account only log(x) + return nd4j::math::nd4j_log(x); + + // coefficients used in truncated asymptotic expansion formula + const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12}; + // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333}; + + const T x2Inv = static_cast(1) / (x * x); + T result = 0; + + for (int i = 6; i >= 0; --i) + result = (result + coeffs[i]) * x2Inv; + return result + nd4j::math::nd4j_log(x) - static_cast(0.5) / x; + } + +} +} +} + + +#endif //LIBND4J_GAMMAMATHFUNC_H diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 027c62484..4531fda81 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -36,7 +36,7 @@ namespace platforms { ////////////////////////////////////////////////////////////////////// static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, - const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, + const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode, const int isNCHW) { int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; @@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md( @@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md( empty); dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, &conv_src_md, nullptr, &conv_weights_md, nullptr, @@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE( - 0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width @@ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW); + conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } @@ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] int kH = INT_ARG(0); // filter(kernel) height @@ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC REQUIRE_TRUE(input->rankOf() == 4, 0, @@ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) { ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), @@ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) { dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; - mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, + mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, gradB, gradO, &conv_src_md, &conv_diff_src_md, &conv_weights_md, @@ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) { if (::optimalLevel() < 2) return false; - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( - 2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - auto gradW = OUTPUT_VARIABLE( - 1); // [kH, kW, iC, oC] always + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index f3f9f9699..1b54889d4 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1568,11 +1568,39 @@ namespace simdOps { return opOutput + old; } - op_def static Z update(X old, X opOutput, X *extraParams) { return opOutput + old; } + op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { + return reduction; + } + }; + + template + class IsNegative { + public: + no_op_exec_special_bool + no_op_exec_special_bool_cuda + + no_op_exec_special_accumulation + no_op_exec_special_accumulation_cuda + + op_def static Z op(X d1, X *params) { + return d1 < (X)0.f; + } + + op_def static X startingValue(const X *input) { + return static_cast(0); + } + + op_def static Z merge(X old, X opOutput, X *extraParams) { + return opOutput + old; + } + + op_def static Z update(X old, X opOutput, X *extraParams) { + return opOutput + old; + } op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { return reduction; diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index b3552a00f..99092b37d 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_7) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, + 61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, + 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, + 221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, + 313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { @@ -1174,14 +1206,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, - 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, - 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, + 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, + 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, - 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, - 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f}); @@ -1252,20 +1284,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, - 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, - 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, - 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, - 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, - 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, - 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, + 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, + 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, + 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, + 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, + 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, + 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, - 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, - 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, - 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, - 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); // auto expGradB('c', {oC},{}); @@ -1302,18 +1334,18 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); - auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, - 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, - 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, - 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, - 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, - 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, - 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, + 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, + 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); - auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, - 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); // auto expGradB('c', {oC},{}); @@ -1350,20 +1382,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); - auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, - 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, - 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, - 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, - 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, - 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, - 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, + 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, + 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, + 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); - auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, - 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); auto expGradB = NDArrayFactory::create('c', {oC},{2.64f, 3.92f, 5.2f}); @@ -1407,9 +1439,9 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, - 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, - 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, + 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1439,7 +1471,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) { auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); - auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1697,13 +1729,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, - 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, + 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); input = 2.; weights = 1.; @@ -1729,13 +1761,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, - 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, - 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1760,9 +1792,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); - auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, - 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1844,8 +1876,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, - 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, + 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); input = 2.; weights = 0.5; @@ -1873,9 +1905,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, - 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, - 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1903,8 +1935,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); - auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, - 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); input = 2.; weights.linspace(0.1, 0.1); @@ -1997,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { auto bias = NDArrayFactory::create('c', {oC}); - auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, - 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, - 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, + 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, + 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); input = 2.; weights.linspace(0.1, 0.1); @@ -2110,20 +2142,20 @@ TEST_F(ConvolutionTests1, vol2col_test2) { auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); columns = -1.; - auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, -10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, -9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, -23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, -0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, -34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, -0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, -48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, -0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, -0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, -0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, + auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, +10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, +9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, +23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, +0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, +34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, +0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, +48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, +0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, +0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); graph::Context context(1); @@ -2164,11 +2196,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); nd4j::ops::upsampling2d op; @@ -2192,11 +2224,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, - 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, - 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, - 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, - 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); nd4j::ops::upsampling2d op; @@ -2221,20 +2253,20 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, - 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, - 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, - 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, - 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, - 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, - 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); nd4j::ops::upsampling3d op; @@ -2258,17 +2290,17 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); input.linspace(1); - auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, - 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, - 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, - 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, - 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, - 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, - 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, - 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, - 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, - 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, - 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); nd4j::ops::upsampling3d op; @@ -2412,13 +2444,13 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); - auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); @@ -2445,13 +2477,13 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); input = 0.5; weights.linspace(0.1, 0.1); @@ -2673,13 +2705,13 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); - auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, - 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, - 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); input = 0.5; weights.linspace(0.1, 0.1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 209b16a7a..76a44be0b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { TEST_F(DeclarableOpsTests13, adjustHue_1) { NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_hue op; - auto results = op.execute({&input}, {0.5}, {2}); + auto results = op.execute({&input, &factor}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) { TEST_F(DeclarableOpsTests13, adjustSaturation_1) { NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); nd4j::ops::adjust_saturation op; - auto results = op.execute({&input}, {0.5}, {2}); + auto results = op.execute({&input, &factor}, {}, {2}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 41dc12a14..b2ccad86f 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { auto x = NDArrayFactory::create('c', {4,4,3}); - auto e = NDArrayFactory::create('c', {4,4,3}, { - -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, - 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, - 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, - 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 - }); + NDArray factor = NDArrayFactory::create(2.); + auto e = NDArrayFactory::create('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + + x.linspace(1.); nd4j::ops::adjust_contrast op; - auto result = op.execute({&x}, {2.}, {}, {}); + auto result = op.execute({&x, &factor}, {}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto out = result->at(0); -// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); delete result; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 269c13f51..a521be97b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) { delete results; } +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test11) { + + NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32); + NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32); + + nd4j::ops::betainc op; + auto results = op.execute({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, zeta_test1) { @@ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) { x.linspace(10.); auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); + nd4j::ops::polygamma op; + auto results = op.execute({&n, &x}, {}, {}); - //ASSERT_FALSE(true); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +TEST_F(DeclarableOpsTests3, polygamma_test4) { + + NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE); + NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE); + + NDArray expected('c', {3,4}, {/*std::numeric_limits::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02, + 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE); nd4j::ops::polygamma op; auto results = op.execute({&n, &x}, {}, {}); @@ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) { delete results; } +TEST_F(DeclarableOpsTests3, digamma_1) { + + NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE); + + NDArray expected('c', {18}, {std::numeric_limits::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331, + std::numeric_limits::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE); + + nd4j::ops::digamma op; + auto results = op.execute({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, svd_test1) {