Shyrma adjust (#98)
* add possibility of passing a scalar-array as input parameter for the scale factor in adjust hue/contrast/saturation ops; correct typo in the function which calculates the regularized incomplete beta integral
* fix bug in betainc cuda kernel
* start working on implementation of the digamma function
* further work on the digamma function (cpu)
* testing and fixing bugs in the digamma op
* make correction in cuda kernel for polyGamma
* remove unnecessary stuff from betaInc cuda kernel
* resolve conflicts in DeclarableOpsTests3.cpp after master branch has been merged
* restore id number of Not operation in legacy_ops.h
* correct padding calculation in mkl dnn conv1d causal
* restore empty check in adjust_contrast_v2

Signed-off-by: Yurii <iuriish@yahoo.com>
Signed-off-by: raver119 <raver119@gmail.com>
parent 1e9ff114aa
commit 1f5e15b541
@@ -112,7 +112,8 @@
         (4, IsInfOrNan), \
         (5, MatchConditionBool), \
         (6, IsPositive) , \
-        (7, Not)
+        (7, Not), \
+        (8, IsNegative)


 #define TRANSFORM_STRICT_OPS \
@@ -279,7 +280,8 @@
         (3, IsInfOrNan), \
         (4, IsNan), \
         (5, IsInf), \
-        (6, IsPositive)
+        (6, IsPositive), \
+        (7, IsNegative)


 #define REDUCE_SAME_OPS \
         (0, Sum), \
@@ -27,7 +27,8 @@
 namespace nd4j {
 namespace ops {

-CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
         return Status::OK();

     REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");
-
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
-    // compute mean before
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
     // fill up axes vector first
     std::vector<int> axes(input->rankOf() - 1);
     for (auto i = 0; i < axes.size(); ++i)
         axes[i] = i;
+
     // mean as reduction for last dimension set
     auto mean = input->reduceAlongDims(reduce::Mean, axes);

-    NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
-    factorT.p(0, factor);
     // this is contrast calculation
-    output->assign((*input - mean) * factorT + mean);
+    output->assign((*input - mean) * (*factor) + mean);
+
+    if(block.width() == 1)
+        delete factor;

     return Status::OK();
 }
@@ -64,45 +73,54 @@ DECLARE_TYPES(adjust_contrast) {
                      ->setSameMode(true);
 }

-
-    CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) {
-
-        auto input  = INPUT_VARIABLE(0);
-        auto output = OUTPUT_VARIABLE(0);
-
-        REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
-
-        const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
-        // just skip op if input is empty
-        if (input->isEmpty())
-            return Status::OK();
-
-        REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
-        REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
-
-        // compute mean before
-        std::vector<int> axes(input->rankOf() - 1);
-        for (auto i = 0; i < axes.size(); ++i)
-            axes[i] = i;
-
-        // mean as reduction for last dimension set
-        auto mean = input->reduceAlongDims(reduce::Mean, axes);
-
-        // result as (x - mean) * factor + mean
-        auto temp = input->ulike();
-        input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
-        temp.applyScalar(scalar::Multiply, factor);
-        temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
-
-        return Status::OK();
-    }
-
-    DECLARE_TYPES(adjust_contrast_v2) {
-        getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
-                ->setAllowedOutputTypes({ALL_FLOATS})
-                ->setSameMode(true);
-    }
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
+
+    auto input  = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    // just skip op if input is empty
+    if (input->isEmpty())
+        return Status::OK();
+
+    REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
+    // compute mean before
+    std::vector<int> axes(input->rankOf() - 1);
+    for (auto i = 0; i < axes.size(); ++i)
+        axes[i] = i;
+
+    // mean as reduction for last dimension set
+    auto mean = input->reduceAlongDims(reduce::Mean, axes);
+
+    // result as (x - mean) * factor + mean
+    auto temp = input->ulike();
+    input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
+    temp.applyScalarArr(scalar::Multiply, factor);
+    temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
+
+    if(block.width() == 1)
+        delete factor;
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(adjust_contrast_v2) {
+    getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
+            ->setAllowedOutputTypes({ALL_FLOATS})
+            ->setSameMode(true);
+}

 }
 }
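Note: with the hunks above, adjust_contrast / adjust_contrast_v2 accept the scale factor either as a T argument or as a scalar-array second input. A minimal usage sketch, assuming the ResultSet-returning DeclarableOp::execute(inputs, tArgs, iArgs) test-style API of this libnd4j version; shapes and values are illustrative only, not taken from this commit:

    #include <NDArrayFactory.h>
    #include <ops/declarable/CustomOperations.h>

    // hypothetical snippet: both calls should yield the same contrast-adjusted image
    auto image  = NDArrayFactory::create<float>('c', {4, 4, 3});    // rank >= 3, last dimension = 3 channels
    auto factor = NDArrayFactory::create<float>(2.f);               // scalar-array scale factor
    image.linspace(1.f);

    nd4j::ops::adjust_contrast_v2 op;
    auto res1 = op.execute({&image},          {2.0}, {});           // factor passed as T argument
    auto res2 = op.execute({&image, &factor}, {},    {});           // factor passed as scalar-array input

    // in this API style the returned ResultSet pointers are owned by the caller
    delete res1;
    delete res2;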
@@ -24,13 +24,12 @@

 #include <ops/declarable/headers/parity_ops.h>
 #include <ops/declarable/helpers/adjust_hue.h>
-#include <NDArrayFactory.h>

 namespace nd4j {
 namespace ops {


-CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {

     const int rank     = input->rankOf();
     const int dimC     = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double delta = T_ARG(0);

+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !");
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
-    REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);

-    NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
+    NDArray* delta = nullptr;

-    helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
+    if(block.width() > 1)
+        delta = INPUT_VARIABLE(1);
+    else {
+        delta = new NDArray(output->dataType(), block.launchContext());
+        delta->p(0, T_ARG(0));
+    }
+
+    REQUIRE_TRUE(-1. <= delta->e<double>(0) && delta->e<double>(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta->e<double>(0));
+
+    helpers::adjustHue(block.launchContext(), input, delta, output, dimC);
+
+    if(block.width() == 1)
+        delete delta;

     return Status::OK();
 }
@@ -28,7 +28,7 @@
 namespace nd4j {
 namespace ops {

-CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -37,16 +37,26 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
     if (input->isEmpty())
         return Status::OK();

-    const int rank     = input->rankOf();
-    const int dimC     = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double factor = T_ARG(0);
+    const int rank = input->rankOf();
+    const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;

     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !");

-    NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
+    NDArray* factor = nullptr;

-    helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
+    helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC);
+
+    if(block.width() == 1)
+        delete factor;

     return Status::OK();
 }
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -15,27 +16,35 @@
  ******************************************************************************/

 //
-// Created by Yurii Shyrma on 13.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

-#ifndef LIBND4J_POLYGAMMA_H
-#define LIBND4J_POLYGAMMA_H
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_digamma)

-#include <ops/declarable/helpers/helpers.h>
-#include "NDArray.h"
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
-namespace ops {
-namespace helpers {
+namespace ops  {

+CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) {

-	// calculate the polygamma function
-    void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output);
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
+
+    helpers::diGamma(block.launchContext(), *x, *z);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(digamma) {
+    getOpDescriptor()
+            ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
+            ->setSameMode(true);
+}

 }
 }
-}

-
-#endif //LIBND4J_POLYGAMMA_H
+#endif
@@ -15,14 +15,14 @@
  ******************************************************************************/

 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_polygamma)

 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/polyGamma.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
 namespace ops  {
@@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) {

     Nd4jLong arrLen = n->lengthOf();
     // FIXME: this shit should be single op call, not a loop!
-    auto nPositive =  n->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    auto xPositive =  x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    bool nPositiveFlag = nPositive.e<bool>(0);
-    bool xPositiveFlag = xPositive.e<bool>(0);
-    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !");
+    auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr);
+    auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
+    bool nPositiveFlag = !nNegative.e<bool>(0);                             // require all n >= 0
+    bool xPositiveFlag = xPositive.e<bool>(0);                              // require all x > 0
+    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !");
     REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !");

     helpers::polyGamma(block.launchContext(), *n, *x, *output);
@@ -513,7 +513,6 @@ namespace nd4j {
         /**
         * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
         * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
-        * Currently the case n = 0 is not supported.
         *
         * Input arrays:
         *    0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
@@ -528,6 +527,20 @@ namespace nd4j {
         DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
         #endif

+        /**
+        * This op calculates digamma function psi(x) = derivative of log(Gamma(x))
+        *
+        * Input arrays:
+        *    0: x - abscissa points where to evaluate the digamma function, type float
+        *
+        * Output array:
+        *    0: values of digamma function at corresponding x, type float
+        *
+        */
+        #if NOT_EXCLUDED(OP_digamma)
+        DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0);
+        #endif
+
         /**
          * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
          * Input arrays:
@@ -575,44 +588,47 @@ namespace nd4j {
          * This operation adjusts image hue by delta
          * Input arrays:
          * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+         * 1 - optional argument, input scalar-array containing delta
          *
          * T arguments:
-         * 0 - delta value
+         * 0 - optional argument, delta value
          *
          * Int arguments:
          * 0 - optional argument, corresponds to dimension with 3 channels
          */
         #if NOT_EXCLUDED(OP_adjust_hue)
-        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0);
         #endif

         /**
          * This operation adjusts image saturation by delta
          * Input arrays:
          * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+         * 1 - optional argument, input scalar-array containing saturation factor
          *
          * T arguments:
-         * 0 - saturation factor
+         * 0 - optional argument, saturation factor
          *
          * Int arguments:
          * 0 - optional argument, corresponds to dimension with 3 channels
          */
         #if NOT_EXCLUDED(OP_adjust_saturation)
-        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0);
         #endif

         /**
          * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
          * Input arrays:
          * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
+         * 1 - optional argument, input scalar-array containing contrast factor
          *
          * T arguments:
-         * 0 - contrast factor
+         * 0 - optional argument, contrast factor
          *
          */
         #if NOT_EXCLUDED(OP_adjust_contrast)
-        DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
-        DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
+        DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0);
+        DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0);
         #endif


@@ -1832,7 +1848,7 @@ namespace nd4j {
         #endif

         /**
-         * compare_and_bitpack - compare with greater and pack result with uint8 
+         * compare_and_bitpack - compare with greater and pack result with uint8
          *
          * input params:
          *    0 - NDArray (input)
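For orientation, the declarations above rely on these standard identities (reference formulas, not part of the patch): the digamma function is the logarithmic derivative of the gamma function, and the polygamma function of order n >= 1 has the Hurwitz-zeta series used by the implementation; the n = 0 case reduces to digamma, which is why the "n = 0 is not supported" note is removed:

    \psi(x) = \frac{d}{dx}\ln\Gamma(x), \qquad
    \psi^{(n)}(x) = (-1)^{n+1}\, n!\;\zeta(n+1,\,x) \quad (n \ge 1), \qquad
    \psi^{(0)}(x) = \psi(x)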
@@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) {
 		return x;

 	const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
+    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);

 	if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-		return front * continuedFraction(a, b, x);
+		return front * continuedFraction(a, b, x) / a;
 	else // symmetry relation
-		return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x);
+		return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;

 }
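The typo fix above moves the division out of `front` and into each branch. The reason is visible from the standard continued-fraction evaluation of the regularized incomplete beta function and its symmetry relation (reference formulas, not part of the patch):

    I_x(a,b) = \frac{x^{a}(1-x)^{b}}{a\,B(a,b)}\,\mathrm{CF}(a,b,x) \quad \text{for } x \le \frac{a+1}{a+b+2},
    \qquad
    I_x(a,b) = 1 - I_{1-x}(b,a) \quad \text{otherwise}

When the symmetry branch mirrors (a, b, x) into (b, a, 1-x), the shared prefactor x^a (1-x)^b / B(a,b) must be divided by b rather than a, which is exactly what dividing inside each branch achieves.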
							
								
								
									
libnd4j/include/ops/declarable/helpers/cpu/diGamma.cpp (new file, 53 lines)
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <execution/Threads.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+// calculate digamma function for array elements
+template <typename T>
+static void diGamma_(const NDArray& x, NDArray& z) {
+
+	auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment)
+            z.p(i, diGammaScalar<T>(x.e<T>(i)));
+    };
+	samediff::Threads::parallel_for(func, 0, x.lengthOf());
+}
+
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+
+	BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES);
+
+}
+}
+}
@@ -18,7 +18,7 @@
 // Created by Yurii Shyrma on 12.12.2017
 //

-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
 #include <execution/Threads.h>
@@ -42,7 +42,7 @@ static FORCEINLINE T getFactorial(const int n) {

 	for(int i = 2; i <= n; ++i)
 		result *= i;
-	 
+
 	return result;
 }

@@ -50,17 +50,15 @@ static FORCEINLINE T getFactorial(const int n) {
 // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
 template <typename T>
 static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, const T x) {
-	 
+
 	// if (n < 0)
 	// 	throw("polyGamma function: n must be >= 0 !");

 	// if (x <= (T)0.)
 	// 	throw("polyGamma function: x must be > 0 !");
-	 
-	// TODO case for n = 0 (digamma)

 	int sign = (n + 1) % 2  ?  -1 : 1;
 	// T factorial = (T)std::tgamma(n + 1);

 	return sign * getFactorial<T>(n) * zetaScalar<T>((T)(n + 1), x);
 }
@@ -71,17 +69,18 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n,
 template <typename T>
 static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

-	NDArray& result = output;
-
-	int xLen = x.lengthOf();
-
 	auto func = PRAGMA_THREADS_FOR {
-        for (auto i = start; i < stop; i += increment)
-            result.p(i, polyGammaScalar<T>(context, n.e<int>(i), x.e<T>(i)));
+        for (auto i = start; i < stop; i += increment) {
+        	const T order = n.e<T>(i);
+        	if(order != static_cast<int>(order))						// if order has fractional part then do not perform calculations and return NAN
+        		output.p(i, std::numeric_limits<T>::quiet_NaN());
+        	else if (order == 0)										// polygamma function of zero order is digamma function
+        		output.p(i, diGammaScalar<T>(x.e<T>(i)));
+        	else
+            	output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
+        }
     };
 	samediff::Threads::parallel_for(func, 0, x.lengthOf());
-
-//	return result;
 }

 	void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
@@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
     return 1.f / 0.f;	// no convergence, more iterations is required
 }

-///////////////////////////////////////////////////////////////////
-// evaluates incomplete beta function for positive a and b, and x between 0 and 1.
-template <typename T>
-__device__ T betaIncCoreCuda(T a, T b, T x) {
-
-	const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
-
-	if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-		return front * continuedFractionCuda(a, b, x);
-	else  // symmetry relation
-		return static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x);
-}
-
 ///////////////////////////////////////////////////////////////////
 template<typename T>
 __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
@@ -115,12 +101,21 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,

     const Nd4jLong j = blockIdx.x;			// one block per each element

-    Nd4jLong len = shape::length(xShapeInfo);
+    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));

-    const T  a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
-    const T  b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
-    const T  x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
-    	  T& z = *(reinterpret_cast<T*>(vz) 	  + shape::getIndexOffset(j, zShapeInfo));
+    __shared__ T a, b, x;
+    __shared__ bool symmCond;
+
+    if (threadIdx.x == 0) {
+
+    	a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
+    	b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
+    	x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
+
+    	symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+    }
+    __syncthreads();

     // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
    	if(a == b && x == static_cast<T>(0.5)) {
@@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,

    	if(threadIdx.x % 2 == 0) { 	/***** even part *****/
 		const int m = threadIdx.x + 1;
-		sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+		if(symmCond)
+			sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+		else
+			sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
 	}
 	else {						/***** odd part *****/
 		const int m = threadIdx.x;
-		sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+		if(symmCond)
+			sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+		else
+			sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
 	}

 	__syncthreads();

-	if(threadIdx.x == 0)
-		z = betaIncCoreCuda(a, b, x);
+	if(threadIdx.x == 0) {
+
+		const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
+	    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
+
+		if (symmCond)
+			z =  front * continuedFractionCuda(a, b, x) / a;
+		else  // symmetry relation
+			z =  static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
+	}
 }

 ///////////////////////////////////////////////////////////////////
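The even/odd shared-memory terms computed above are the standard coefficients of the continued fraction for the incomplete beta function, evaluated either for (a, b, x) or, in the symmetry branch, for (b, a, 1-x) (reference formulas, not part of the patch):

    d_{2m} = \frac{m\,(b-m)\,x}{(a+2m-1)(a+2m)}, \qquad
    d_{2m+1} = -\,\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)}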
							
								
								
									
libnd4j/include/ops/declarable/helpers/cuda/diGamma.cu (new file, 78 lines)
@@ -0,0 +1,78 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <NDArrayFactory.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo,
+                                     	 void *vz, const Nd4jLong *zShapeInfo) {
+
+    const auto x = reinterpret_cast<const T*>(vx);
+          auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ Nd4jLong len;
+    __shared__ bool sameOffset;
+
+    if (threadIdx.x == 0) {
+        len = shape::length(xShapeInfo);
+        sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
+    }
+    __syncthreads();
+
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) {
+
+        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo);
+
+        z[zOffset] = diGammaScalar<T>(x[xOffset]);
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
+
+    diGammaCuda<T><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
+}
+
+///////////////////////////////////////////////////////////////////
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
+    int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+
+    NDArray::prepareSpecialUse({&z}, {&x});
+    BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
+    NDArray::registerSpecialUse({&z}, {&x});
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);
+
+}
+}
+}
@@ -18,7 +18,7 @@
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019
 //

-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>

@@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
           auto z = reinterpret_cast<T*>(vz);

     __shared__ Nd4jLong len;
+    __shared__ bool sameOffsetNX, sameOffsetNZ;

-    if (threadIdx.x == 0)
+    if (threadIdx.x == 0) {
         len = shape::length(nShapeInfo);
+        sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo);
+        sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo);
+    }
     __syncthreads();

     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     for (int i = tid; i < len; i += totalThreads) {

         const auto nOffset = shape::getIndexOffset(i, nShapeInfo);
-        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
-        const auto zOffset = shape::getIndexOffset(i, zShapeInfo);
+        const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo);

-        const T nVal = n[nOffset];
+        const T order = n[nOffset];

-        int sign = (static_cast<int>(nVal) + 1) % 2  ?  -1 : 1;
+        int sign = (static_cast<int>(order) + 1) % 2  ?  -1 : 1;

-        T factorial = 1;
-        if(nVal != 0 && nVal != 1)
-        	for(int i = 2; i <= nVal; ++i)
-				factorial *= i;
+        if(order != static_cast<int>(order)) {
+            z[zOffset] = DataTypeUtils::nanOrZero<T>();
+        }
+        else if(order == 0) {
+            z[zOffset] = diGammaScalar<T>(x[xOffset]);
+        }
+        else {
+            T factorial = 1;
+            for(int i = 2; i <= order; ++i)
+                factorial *= i;

-        z[zOffset] = sign * factorial * zetaScalar<T>(nVal + 1, x[xOffset]);
+            z[zOffset] = sign * factorial * zetaScalar<T>(order + 1, x[xOffset]);
+        }
     }
 }

@@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x

     NDArray::prepareSpecialUse({&z}, {&n, &x});

-    int threadsPerBlock = MAX_NUM_THREADS;
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
							
								
								
									
libnd4j/include/ops/declarable/helpers/gammaMathFunc.h (new file, 100 lines)
@@ -0,0 +1,100 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#ifndef LIBND4J_GAMMAMATHFUNC_H
+#define LIBND4J_GAMMAMATHFUNC_H
+
+#include <ops/declarable/helpers/helpers.h>
+#include "NDArray.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    // calculate the digamma function for each element of the array
+    void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z);
+
+    // calculate the polygamma function
+    void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z);
+
+    // calculate the digamma function for one element
+    // the related polygamma implementation is based on the serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
+    template <typename T>
+    _CUDA_HD T diGammaScalar(T x) {
+
+        const int xInt = static_cast<int>(x);
+
+        // negative and zero
+        if(x <= 0) {
+            if(x == xInt)    // integer
+                return DataTypeUtils::infOrMax<T>();
+            else
+                return diGammaScalar<T>(1 - x) - M_PI / nd4j::math::nd4j_tan<T,T>(M_PI * x); // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x)
+        }
+
+        // positive integer
+        if(x == xInt && xInt <= 20) {        // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf; we use this formula only for n <= 20 to avoid time-consuming sum calculation for bigger n
+            T result = -0.577215664901532;
+            for (uint i = 1; i <= xInt - 1; ++i) {
+                result += static_cast<T>(1) / i;
+            }
+            return result;
+        }
+
+        // positive half-integer
+        if(x - xInt == 0.5 && xInt <= 20) {        // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ), for n = 1,2,3,...inf; we use this formula only for n <= 20 to avoid time-consuming sum calculation for bigger n
+            T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log<T,T>(2);
+            for (uint i = 1; i <= xInt; ++i) {
+                result += static_cast<T>(2) / (2*i - 1);
+            }
+            return result;
+        }
+
+        // positive, smaller than 5; we should use a number > 5 in order to have satisfactory accuracy in the asymptotic expansion
+        if(x < 5)
+            return diGammaScalar<T>(1 + x) - static_cast<T>(1) / x;        // recurrence formula  psi(x) = psi(x+1) - 1/x
+
+        // *** other positive *** //
+
+        // truncated expansion formula (from wiki)
+        // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ...
+
+        if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8))        // if x is too big take into account only log(x)
+            return nd4j::math::nd4j_log<T,T>(x);
+
+        // coefficients used in truncated asymptotic expansion formula
+        const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12};
+        // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333};
+
+        const T x2Inv = static_cast<T>(1) / (x * x);
+        T result = 0;
+
+        for (int i = 6; i >= 0; --i)
+            result = (result + coeffs[i]) * x2Inv;
+        return result + nd4j::math::nd4j_log<T,T>(x) - static_cast<T>(0.5) / x;
+    }
+
+}
+}
+}
+
+#endif //LIBND4J_GAMMAMATHFUNC_H
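For reference, the truncated asymptotic expansion implemented by the coefficient loop above, and two closed-form values that the integer and half-integer branches return directly (standard results, stated here only for orientation):

    \psi(x) \sim \ln x - \frac{1}{2x} - \frac{1}{12x^{2}} + \frac{1}{120x^{4}} - \frac{1}{252x^{6}} + \frac{1}{240x^{8}} - \frac{5}{660x^{10}} + \frac{691}{32760x^{12}} - \frac{1}{12x^{14}} + \cdots

    \psi(1) = -\gamma \approx -0.5772156649, \qquad \psi(\tfrac{1}{2}) = -\gamma - 2\ln 2 \approx -1.9635100260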
@@ -36,7 +36,7 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////
 static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                           const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
-                          const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
+                          const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
                           const int isNCHW) {

     int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
@@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
                                                indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

-    if(isSameMode)                       // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

     dnnl_memory_desc_t empty;
     dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
@@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
             empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
                                            bias, output,
                                            &conv_src_md, nullptr, &conv_weights_md, nullptr,
@@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con

 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d) {
-    auto input = INPUT_VARIABLE(
-            0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto input = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always
 |     auto weights = INPUT_VARIABLE(1);                                    // [kH, kW, iC, oC] always
 | ||||||
|     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 |     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]
 | ||||||
| 
 | 
 | ||||||
|     auto output = OUTPUT_VARIABLE( |     auto output = OUTPUT_VARIABLE(0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 | ||||||
|             0);                                   // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 |  | ||||||
| 
 | 
 | ||||||
|     int sH = INT_ARG(2);                                                        // strides height
 |     int sH = INT_ARG(2);                                                        // strides height
 | ||||||
|     int sW = INT_ARG(3);                                                        // strides width
 |     int sW = INT_ARG(3);                                                        // strides width
 | ||||||
| @ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) { | |||||||
|     int pW = INT_ARG(5);                                                        // paddings width
 |     int pW = INT_ARG(5);                                                        // paddings width
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 |     bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;       // INT_ARG(9): 0-NCHW,  1-NHWC
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 |     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
 | ||||||
|     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 |     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 | ||||||
| 
 | 
 | ||||||
|     conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW); |     conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); | ||||||
| 
 | 
 | ||||||
|     return Status::OK(); |     return Status::OK(); | ||||||
| } | } | ||||||
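Note (illustration only, not taken from the repository's tests): the integer arguments read above are consumed in the order {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, format}. A hypothetical call site in the style of the tests further down in this diff would look like:

    NDArray input('c', {2, 5, 5, 3}, nd4j::DataType::FLOAT32);      // [bS, iH, iW, iC], NHWC
    NDArray weights('c', {3, 3, 3, 4}, nd4j::DataType::FLOAT32);    // [kH, kW, iC, oC]

    nd4j::ops::conv2d op;
    // kH=3, kW=3, sH=sW=1, pH=pW=0, dH=dW=1, paddingMode=1 (SAME), INT_ARG(9)=1 (NHWC)
    auto results = op.execute({&input, &weights}, {}, {3,3, 1,1, 0,0, 1,1, 1, 1});
    delete results;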
| @ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) { | |||||||
| 
 | 
 | ||||||
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| PLATFORM_IMPL(conv2d_bp) { | PLATFORM_IMPL(conv2d_bp) { | ||||||
|     auto input = INPUT_VARIABLE( |     auto input = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 | ||||||
|     auto weights = INPUT_VARIABLE( |  | ||||||
|             1);                                                // [kH, kW, iC, oC] always
 |  | ||||||
|     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( |     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|             2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |  | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE( |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|             0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, oC] always
 | ||||||
|     auto gradW = OUTPUT_VARIABLE( |  | ||||||
|             1);                                                 // [kH, kW, iC, oC] always
 |  | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
|     int kH = INT_ARG(0);                                                        // filter(kernel) height
 |     int kH = INT_ARG(0);                                                        // filter(kernel) height
 | ||||||
| @ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) { | |||||||
|     int pW = INT_ARG(5);                                                        // paddings width
 |     int pW = INT_ARG(5);                                                        // paddings width
 | ||||||
|     int dH = INT_ARG(6);                                                        // dilations height
 |     int dH = INT_ARG(6);                                                        // dilations height
 | ||||||
|     int dW = INT_ARG(7);                                                        // dilations width
 |     int dW = INT_ARG(7);                                                        // dilations width
 | ||||||
|     int isSameMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 |     int paddingMode = INT_ARG(8);                                                // 0-VALID, 1-SAME
 | ||||||
|     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 |     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;          // INT_ARG(9): 0-NCHW, 1-NHWC
 | ||||||
| 
 | 
 | ||||||
|     REQUIRE_TRUE(input->rankOf() == 4, 0, |     REQUIRE_TRUE(input->rankOf() == 4, 0, | ||||||
| @ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) { | |||||||
|     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, |     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, | ||||||
|                                                indIiH, indWiC, indWoC, indWkH, indOoH); |                                                indIiH, indWiC, indWoC, indWkH, indOoH); | ||||||
| 
 | 
 | ||||||
|     if (isSameMode)                       // SAME
 |     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); | ||||||
|         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); |  | ||||||
| 
 | 
 | ||||||
|     dnnl_memory_desc_t empty; |     dnnl_memory_desc_t empty; | ||||||
|     dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), |     dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty), | ||||||
| @ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) { | |||||||
|     dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), |     dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty), | ||||||
|             user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); |             user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty); | ||||||
|     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; |     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; | ||||||
|     mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW, |     mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, | ||||||
|                                            bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, |                                            bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW, | ||||||
|                                            gradB, gradO, |                                            gradB, gradO, | ||||||
|                                            &conv_src_md, &conv_diff_src_md, &conv_weights_md, |                                            &conv_src_md, &conv_diff_src_md, &conv_weights_md, | ||||||
| @ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) { | |||||||
|     if (::optimalLevel() < 2) |     if (::optimalLevel() < 2) | ||||||
|         return false; |         return false; | ||||||
| 
 | 
 | ||||||
|     auto input = INPUT_VARIABLE( |     auto input = INPUT_VARIABLE(0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 | ||||||
|             0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
 |     auto weights = INPUT_VARIABLE(1);                                                // [kH, kW, iC, oC] always
 | ||||||
|     auto weights = INPUT_VARIABLE( |  | ||||||
|             1);                                                // [kH, kW, iC, oC] always
 |  | ||||||
|     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 |     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
 | ||||||
|     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE( |     auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 | ||||||
|             2);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 |  | ||||||
| 
 | 
 | ||||||
|     auto gradI = OUTPUT_VARIABLE( |     auto gradI = OUTPUT_VARIABLE(0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 | ||||||
|             0);                                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
 |     auto gradW = OUTPUT_VARIABLE(1);                                                 // [kH, kW, iC, oC] always
 | ||||||
|     auto gradW = OUTPUT_VARIABLE( |  | ||||||
|             1);                                                 // [kH, kW, iC, oC] always
 |  | ||||||
|     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 |     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;                   // [oC]
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
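Note (illustration only): calcPadding2D is now called unconditionally with paddingMode instead of being guarded by the former if(isSameMode) branch; its body is not part of this diff. A minimal sketch of what SAME-mode padding has to produce along one spatial dimension (samePad1D is a hypothetical helper, not ConvolutionUtils::calcPadding2D):

    #include <algorithm>

    // Top/left padding for SAME mode along one dimension, given output size o,
    // input size i, kernel k, stride s and dilation d; the remainder of the
    // total padding goes to the bottom/right side.
    static int samePad1D(int o, int i, int k, int s, int d) {
        const int effK  = (k - 1) * d + 1;                       // effective kernel extent with dilation
        const int total = std::max((o - 1) * s + effK - i, 0);   // total padding required
        return total / 2;                                        // top/left share
    }
    // e.g. o = i = 4, k = 3, s = 1, d = 1  ->  total = 2, so pH (or pW) = 1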
| @ -1568,11 +1568,39 @@ namespace simdOps { | |||||||
|             return opOutput + old; |             return opOutput + old; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|         op_def static Z update(X old, X opOutput, X *extraParams) { |         op_def static Z update(X old, X opOutput, X *extraParams) { | ||||||
|             return opOutput + old; |             return opOutput + old; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|  |         op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { | ||||||
|  |             return reduction; | ||||||
|  |         } | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
|  | 	template <typename X, typename Z> | ||||||
|  | 	class IsNegative { | ||||||
|  | 	public: | ||||||
|  | 		no_op_exec_special_bool | ||||||
|  | 		no_op_exec_special_bool_cuda | ||||||
|  | 
 | ||||||
|  | 		no_op_exec_special_accumulation | ||||||
|  | 		no_op_exec_special_accumulation_cuda | ||||||
|  | 
 | ||||||
|  | 		op_def static Z op(X d1, X *params) { | ||||||
|  | 			return d1 < (X)0.f; | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  |         op_def static X startingValue(const X *input) { | ||||||
|  |             return static_cast<X>(0); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         op_def static Z merge(X old, X opOutput, X *extraParams) { | ||||||
|  |             return opOutput + old; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         op_def static Z update(X old, X opOutput, X *extraParams) { | ||||||
|  |             return opOutput + old; | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { |         op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) { | ||||||
|             return reduction; |             return reduction; | ||||||
|  | |||||||
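Note (illustration only): the new IsNegative functor mirrors IsPositive above it: op() yields the per-element boolean, startingValue()/update()/merge() sum the partial results, and postProcess() returns the sum unchanged, so the reduce-bool form effectively counts matching elements. A standalone sketch of that contract (countNegative is a hypothetical stand-in for the library's reduction executor, not part of the patch):

    #include <cstdio>

    // Mimics the startingValue / op / update / postProcess sequence of a
    // reduce-bool functor such as IsNegative on a plain array.
    template <typename X>
    static X countNegative(const X *data, int n) {
        X acc = static_cast<X>(0);                                    // startingValue()
        for (int i = 0; i < n; ++i)
            acc = acc + static_cast<X>(data[i] < static_cast<X>(0));  // update(acc, op(data[i]))
        return acc;                                                   // postProcess() is the identity here
    }

    int main() {
        const float v[6] = {-1.f, 0.f, 2.f, -3.5f, 4.f, -0.25f};
        std::printf("%f\n", countNegative(v, 6));                     // prints 3.000000
        return 0;
    }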
| @ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) { | |||||||
|     delete results; |     delete results; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | //////////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(ConvolutionTests1, conv1d_causal_7) { | ||||||
|  | 
 | ||||||
|  |     int bS=2, iW=8,  iC=3,oC=4,  kW=2,  sW=1,  pW=0,  dW=1; | ||||||
|  |     int oW = (iW-1)/sW + 1; | ||||||
|  |     int paddingMode = 2;             // CAUSAL
 | ||||||
|  |     int dataFormat  = 1;             // 1-NHWC, 0-NCHW
 | ||||||
|  | 
 | ||||||
|  |     NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, | ||||||
|  |         61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, | ||||||
|  |         140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, | ||||||
|  |         221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, | ||||||
|  |         313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     input.linspace(1., 1.); | ||||||
|  |     weights.linspace(0.1, 0.1); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::conv1d op; | ||||||
|  |     auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW,  paddingMode, dataFormat}); | ||||||
|  |     auto output = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(Status::OK(), results->status()); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(expOutput.isSameShape(output)); | ||||||
|  |     ASSERT_TRUE(expOutput.equalsTo(output)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
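Note (hand check of the new expected values, not part of the patch): with paddingMode = 2 the convolution is causal, i.e. the input is padded on the left by (kW - 1) * dW = 1 step, so oW = (iW - 1)/sW + 1 = 8 and output position t only sees inputs at positions <= t. For the linspaced input/weights above, the first output row reproduces the first four expected values:

    #include <cstdio>

    // Recomputes output[b=0, t=0, :] of conv1d_causal_7 by hand: the left-padded
    // tap is zero, so only input step 0 contributes, weighted by the kW=1 slice
    // of the kernel. Shapes follow the test: input {2,8,3}, weights {2,3,4}.
    int main() {
        const float x0[3] = {1.f, 2.f, 3.f};                  // input[0, 0, :] after linspace(1, 1)
        float w1[3][4];                                        // weights[1, :, :] after linspace(0.1, 0.1)
        for (int ic = 0; ic < 3; ++ic)
            for (int oc = 0; oc < 4; ++oc)
                w1[ic][oc] = 0.1f * (12 + ic * 4 + oc + 1);    // flat row-major index kw*12 + ic*4 + oc
        for (int oc = 0; oc < 4; ++oc) {
            float y = 0.f;
            for (int ic = 0; ic < 3; ++ic)
                y += x0[ic] * w1[ic][oc];
            std::printf("%f ", y);                             // 11.000000 11.600000 12.200000 12.800000
        }
        std::printf("\n");
        return 0;
    }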
| //////////////////////////////////////////////////////////////////////
 | //////////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { | TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { | ||||||
| 
 | 
 | ||||||
| @ -1174,14 +1206,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { | |||||||
|     auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); |     auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); | ||||||
|     auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}); |     auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}); | ||||||
| 
 | 
 | ||||||
|     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f,  |     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, | ||||||
|                                                      0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f,  |                                                      0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, | ||||||
|                                                      1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f,  |                                                      1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, | ||||||
|                                                      2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); |                                                      2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); | ||||||
| 
 | 
 | ||||||
|     auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f,  |     auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, | ||||||
|                                                     1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f,  |                                                     1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, | ||||||
|                                                     2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f,  |                                                     2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, | ||||||
|                                                     2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); |                                                     2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); | ||||||
|     auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68f, 1.f, 1.32f}); |     auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68f, 1.f, 1.32f}); | ||||||
| 
 | 
 | ||||||
| @ -1252,20 +1284,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { | |||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); | ||||||
|     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); |     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); | ||||||
| 
 | 
 | ||||||
|     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f,  |     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, | ||||||
|                                                         5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f,  |                                                         5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, | ||||||
|                                                        28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f,  |                                                        28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, | ||||||
|                                                        58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f,  |                                                        58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, | ||||||
|                                                         9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f,  |                                                         9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, | ||||||
|                                                        29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f,  |                                                        29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, | ||||||
|                                                       148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f,  |                                                       148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, | ||||||
|                                                       178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); |                                                       178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); | ||||||
| 
 | 
 | ||||||
|     auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f,  |     auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, | ||||||
|                                                         154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f,  |                                                         154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, | ||||||
|                                                         111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f,  |                                                         111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, | ||||||
|                                                          67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f,  |                                                          67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, | ||||||
|                                                          85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f,  |                                                          85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, | ||||||
|                                                          61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); |                                                          61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); | ||||||
|     // auto expGradB('c', {oC},{});
 |     // auto expGradB('c', {oC},{});
 | ||||||
| 
 | 
 | ||||||
| @ -1302,18 +1334,18 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { | |||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); | ||||||
|     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); |     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); | ||||||
| 
 | 
 | ||||||
|     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f,  |     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, | ||||||
|                                                          1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f,  |                                                          1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, | ||||||
|                                                          6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f,  |                                                          6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, | ||||||
|                                                          8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f,  |                                                          8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, | ||||||
|                                                          0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f,  |                                                          0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, | ||||||
|                                                          4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f,  |                                                          4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, | ||||||
|                                                         28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f,  |                                                         28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, | ||||||
|                                                         20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); |                                                         20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); | ||||||
| 
 | 
 | ||||||
|     auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,  |     auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, | ||||||
|                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,  |                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, | ||||||
|                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,  |                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, | ||||||
|                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); |                                                         7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); | ||||||
|     // auto expGradB('c', {oC},{});
 |     // auto expGradB('c', {oC},{});
 | ||||||
| 
 | 
 | ||||||
| @ -1350,20 +1382,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { | |||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); | ||||||
|     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW}); |     auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW}); | ||||||
| 
 | 
 | ||||||
|     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f,  |     auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, | ||||||
|                                                         2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f,  |                                                         2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, | ||||||
|                                                         2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f,  |                                                         2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, | ||||||
|                                                         3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f,  |                                                         3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, | ||||||
|                                                         5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f,  |                                                         5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, | ||||||
|                                                         6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f,  |                                                         6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, | ||||||
|                                                         7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f,  |                                                         7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, | ||||||
|                                                         9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); |                                                         9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); | ||||||
| 
 | 
 | ||||||
|     auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,  |     auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, | ||||||
|                                                         5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,  |                                                         5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, | ||||||
|                                                         7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,  |                                                         7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, | ||||||
|                                                         7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,  |                                                         7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, | ||||||
|                                                         10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f,  |                                                         10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, | ||||||
|                                                         10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); |                                                         10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); | ||||||
| 
 | 
 | ||||||
|     auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64f, 3.92f, 5.2f}); |     auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64f, 3.92f, 5.2f}); | ||||||
| @ -1407,9 +1439,9 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) { | |||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC}); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,  |     auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, | ||||||
|                                                      13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,  |                                                      13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, | ||||||
|                                                      12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,  |                                                      12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, | ||||||
|                                                      13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); |                                                      13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @ -1439,7 +1471,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) { | |||||||
|     auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, mC}); |     auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, mC}); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f,  |     auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f, | ||||||
|                                                      13.2f,  14.4f,  15.6f,  16.8f, 13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f, 13.2f,  14.4f,  15.6f,  16.8f}); |                                                      13.2f,  14.4f,  15.6f,  16.8f, 13.2f,  14.4f,  15.6f,  16.8f,  13.2f,  14.4f,  15.6f,  16.8f, 13.2f,  14.4f,  15.6f,  16.8f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @ -1697,13 +1729,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, | ||||||
|                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,  |                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, | ||||||
|                                                    96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,  |                                                    96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, | ||||||
|                                                    48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,  |                                                    48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, | ||||||
|                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,  |                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, | ||||||
|                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,  |                                                    64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, | ||||||
|                                                    96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,  |                                                    96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f, | ||||||
|                                                    48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); |                                                    48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights = 1.; |     weights = 1.; | ||||||
| @@ -1729,13 +1761,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, | ||||||
|                                                    380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,  |                                                    380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, | ||||||
|                                                    686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,  |                                                    686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, | ||||||
|                                                    170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,  |                                                    170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, | ||||||
|                                                    534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,  |                                                    534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, | ||||||
|                                                    380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,  |                                                    380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, | ||||||
|                                                    686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,  |                                                    686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, | ||||||
|                                                    170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); |                                                    170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @@ -1760,9 +1792,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3},  {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3},  {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, | ||||||
|                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,  |                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, | ||||||
|                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,  |                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, | ||||||
|                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); |                                                     686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @@ -1844,8 +1876,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { | |||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); | ||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49.f,  49.f, 49.f,  49.f,  49.f,  49.f, 49.f,  49.f,  50.f,  50.f, 50.f,  50.f,  50.f,  50.f, 50.f,  50.f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49.f,  49.f, 49.f,  49.f,  49.f,  49.f, 49.f,  49.f,  50.f,  50.f, 50.f,  50.f,  50.f,  50.f, 50.f,  50.f, | ||||||
|                                                   51.f,  51.f, 51.f,  51.f,  51.f,  51.f, 51.f,  51.f,  49.f,  49.f, 49.f,  49.f,  49.f,  49.f, 49.f,  49.f,  |                                                   51.f,  51.f, 51.f,  51.f,  51.f,  51.f, 51.f,  51.f,  49.f,  49.f, 49.f,  49.f,  49.f,  49.f, 49.f,  49.f, | ||||||
|                                                   50.f,  50.f, 50.f,  50.f,  50.f,  50.f, 50.f,  50.f,  51.f,  51.f, 51.f,  51.f,  51.f,  51.f, 51.f,  51.f}); |                                                   50.f,  50.f, 50.f,  50.f,  50.f,  50.f, 50.f,  50.f,  51.f,  51.f, 51.f,  51.f,  51.f,  51.f, 51.f,  51.f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights = 0.5; |     weights = 0.5; | ||||||
| @@ -1873,9 +1905,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { | |||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); | ||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, | ||||||
|                                                   698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f,  |                                                   698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, | ||||||
|                                                   236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,  |                                                   236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, | ||||||
|                                                   698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); |                                                   698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @@ -1903,8 +1935,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); | ||||||
|     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f,  |     auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, | ||||||
|                                                   1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f,  |                                                   1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, | ||||||
|                                                   696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); |                                                   696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @@ -1997,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { | |||||||
|     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}); |     auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f,  |     auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, | ||||||
|                                                       7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f,  |                                                       7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, | ||||||
|                                                       6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f,  |                                                       6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, | ||||||
|                                                       5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); |                                                       5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); | ||||||
|     input = 2.; |     input = 2.; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @@ -2110,20 +2142,20 @@ TEST_F(ConvolutionTests1, vol2col_test2) { | |||||||
|     auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH}); |     auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH}); | ||||||
|     columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); |     columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); | ||||||
|     columns = -1.; |     columns = -1.; | ||||||
|     auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,  |     auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, | ||||||
| 10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f,  | 10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, | ||||||
| 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f,  | 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, | ||||||
| 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f,  | 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, | ||||||
| 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,  | 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, | ||||||
| 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f,  | 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, | ||||||
| 0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f,  | 0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, | ||||||
| 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f,  | 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f, | ||||||
| 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f,  | 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, | ||||||
| 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,  | 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, | ||||||
| 0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f,  | 0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, | ||||||
| 0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f,  | 0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, | ||||||
| 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f,  | 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, | ||||||
| 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f,  | 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, | ||||||
| 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); | 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); | ||||||
| 
 | 
 | ||||||
|     graph::Context context(1); |     graph::Context context(1); | ||||||
| @@ -2164,11 +2196,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) { | |||||||
|     auto input  = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); |     auto input  = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); | ||||||
|     input.linspace(1); |     input.linspace(1); | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f,  |     auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, | ||||||
|                                          7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,  |                                          7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, | ||||||
|                                         13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,  |                                         13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, | ||||||
|                                         19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f,  |                                         19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, | ||||||
|                                         25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,  |                                         25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, | ||||||
|                                         31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); |                                         31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::upsampling2d op; |     nd4j::ops::upsampling2d op; | ||||||
| @@ -2192,11 +2224,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) { | |||||||
|     auto input  = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}); |     auto input  = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}); | ||||||
|     input.linspace(1); |     input.linspace(1); | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f,  |     auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, | ||||||
|                                  5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f,  |                                  5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, | ||||||
|                                 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f,  |                                 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, | ||||||
|                                 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f,  |                                 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, | ||||||
|                                 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f,  |                                 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, | ||||||
|                                 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); |                                 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::upsampling2d op; |     nd4j::ops::upsampling2d op; | ||||||
| @@ -2221,20 +2253,20 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) { | |||||||
|     auto input  = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC}); |     auto input  = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC}); | ||||||
|     input.linspace(1); |     input.linspace(1); | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,  |     auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, | ||||||
|              7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,  |              7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, | ||||||
|              7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,  |              7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, | ||||||
|             19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,  |             19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, | ||||||
|             13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,  |             13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, | ||||||
|             25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,  |             25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, | ||||||
|             25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,  |             25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, | ||||||
|             31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f,  |             31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, | ||||||
|             43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f,  |             43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, | ||||||
|             43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,  |             43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, | ||||||
|             49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,  |             49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, | ||||||
|             49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f,  |             49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, | ||||||
|             61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,  |             61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, | ||||||
|             67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,  |             67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, | ||||||
|             67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); |             67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::upsampling3d op; |     nd4j::ops::upsampling3d op; | ||||||
| @@ -2258,17 +2290,17 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) { | |||||||
|     auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW}); |     auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW}); | ||||||
|     input.linspace(1); |     input.linspace(1); | ||||||
| 
 | 
 | ||||||
|     auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f,  |     auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, | ||||||
|              5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f,  |              5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, | ||||||
|             13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f,  |             13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, | ||||||
|             17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f,  |             17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, | ||||||
|             25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f,  |             25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, | ||||||
|             29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f,  |             29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, | ||||||
|             37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f,  |             37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, | ||||||
|             41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f,  |             41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, | ||||||
|             49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f,  |             49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, | ||||||
|             53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f,  |             53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, | ||||||
|             61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f,  |             61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, | ||||||
|             65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); |             65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::upsampling3d op; |     nd4j::ops::upsampling3d op; | ||||||
| @ -2412,13 +2444,13 @@ TEST_F(ConvolutionTests1, deconv2d_test1) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); |     auto input    = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); | ||||||
|     auto weights  = NDArrayFactory::create<float>('c', {kH, kW, oC, iC}); |     auto weights  = NDArrayFactory::create<float>('c', {kH, kW, oC, iC}); | ||||||
|     auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {  2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,  |     auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {  2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, | ||||||
|                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,  |                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, | ||||||
|                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,  |                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, | ||||||
|                                                   52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f,  |                                                   52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, | ||||||
|                                                    2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,  |                                                    2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, | ||||||
|                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,  |                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, | ||||||
|                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,  |                                                   55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, | ||||||
|                                                   52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); |                                                   52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); | ||||||
|     input = 0.5; |     input = 0.5; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @ -2445,13 +2477,13 @@ TEST_F(ConvolutionTests1, deconv2d_test2) { | |||||||
| 
 | 
 | ||||||
|     auto input    = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}); |     auto input    = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}); | ||||||
|     auto weights  = NDArrayFactory::create<float>('c', {kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<float>('c', {kH, kW, iC, oC}); | ||||||
|     auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  |     auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f,  |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f,  |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f,  |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, | ||||||
|                                                  2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  |                                                  2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f,  |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f,  |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, | ||||||
|                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f  }); |                                                 55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f  }); | ||||||
|     input = 0.5; |     input = 0.5; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
| @ -2673,13 +2705,13 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { | |||||||
|     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC}); |     auto input    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC}); | ||||||
|     auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC}); |     auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC}); | ||||||
|     auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)}); |     auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)}); | ||||||
|     auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {  2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  27.75f,   32.75f,   37.75f,   42.75f,   47.75f,  |     auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {  2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  27.75f,   32.75f,   37.75f,   42.75f,   47.75f, | ||||||
|                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f,  |                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f, | ||||||
|                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f,  |                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f, | ||||||
|                                                   52.75f,   57.75f,   62.75f,   67.75f,   72.75f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f,  77.75f,   82.75f,   87.75f,   92.75f,   97.75f,  |                                                   52.75f,   57.75f,   62.75f,   67.75f,   72.75f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f,  77.75f,   82.75f,   87.75f,   92.75f,   97.75f, | ||||||
|                                                    2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  27.75f,   32.75f,   37.75f,   42.75f,   47.75f,  |                                                    2.75f,    7.75f,   12.75f,   17.75f,   22.75f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  30.5f,   40.5f,   50.5f,   60.5f,   70.5f,  27.75f,   32.75f,   37.75f,   42.75f,   47.75f, | ||||||
|                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f,  |                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f, | ||||||
|                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f,  |                                                   55.5f,   65.5f,   75.5f,   85.5f,   95.5f, 161.f,  181.f,  201.f,  221.f,  241.f, 161.f,  181.f,  201.f,  221.f,  241.f, 105.5f,  115.5f,  125.5f,  135.5f,  145.5f, | ||||||
|                                                   52.75f,   57.75f,   62.75f,   67.75f,   72.75f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f,  77.75f,   82.75f,   87.75f,   92.75f,   97.75f}); |                                                   52.75f,   57.75f,   62.75f,   67.75f,   72.75f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f, 130.5f,  140.5f,  150.5f,  160.5f,  170.5f,  77.75f,   82.75f,   87.75f,   92.75f,   97.75f}); | ||||||
|     input = 0.5; |     input = 0.5; | ||||||
|     weights.linspace(0.1, 0.1); |     weights.linspace(0.1, 0.1); | ||||||
|  | |||||||
| @ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) { | |||||||
| TEST_F(DeclarableOpsTests13, adjustHue_1) { | TEST_F(DeclarableOpsTests13, adjustHue_1) { | ||||||
| 
 | 
 | ||||||
|     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5,  150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); |     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5,  150,97,230, 255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray factor = NDArrayFactory::create<float>(0.5); | ||||||
|     NDArray exp  ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97,  2,255,244}, nd4j::DataType::FLOAT32); |     NDArray exp  ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97,  2,255,244}, nd4j::DataType::FLOAT32); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::adjust_hue op; |     nd4j::ops::adjust_hue op; | ||||||
|     auto results = op.execute({&input}, {0.5}, {2}); |     auto results = op.execute({&input, &factor}, {}, {2}); | ||||||
| 
 | 
 | ||||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
| 
 | 
 | ||||||
| @ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) { | |||||||
| TEST_F(DeclarableOpsTests13, adjustSaturation_1) { | TEST_F(DeclarableOpsTests13, adjustSaturation_1) { | ||||||
| 
 | 
 | ||||||
|     NDArray input('c', {2,2,3}, {0,100,56,  17,220,5,         150,97,230,    255,2,13}, nd4j::DataType::FLOAT32); |     NDArray input('c', {2,2,3}, {0,100,56,  17,220,5,         150,97,230,    255,2,13}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray factor = NDArrayFactory::create<float>(0.5); | ||||||
|     NDArray exp  ('c', {2,2,3}, {50,100,78, 118.5,220,112.5,  190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); |     NDArray exp  ('c', {2,2,3}, {50,100,78, 118.5,220,112.5,  190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::adjust_saturation op; |     nd4j::ops::adjust_saturation op; | ||||||
|     auto results = op.execute({&input}, {0.5}, {2}); |     auto results = op.execute({&input, &factor}, {}, {2}); | ||||||
| 
 | 
 | ||||||
|     ASSERT_EQ(ND4J_STATUS_OK, results->status()); |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { | |||||||
| 
 | 
 | ||||||
| TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { | TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { | ||||||
|     auto x = NDArrayFactory::create<double>('c', {4,4,3}); |     auto x = NDArrayFactory::create<double>('c', {4,4,3}); | ||||||
|     auto e = NDArrayFactory::create<double>('c', {4,4,3}, { |     NDArray factor = NDArrayFactory::create<double>(2.); | ||||||
|         -21.5, -20.5, -19.5,  -15.5, -14.5, -13.5,  -9.5,  -8.5,  -7.5,  -3.5,  -2.5,  -1.5, |     auto e = NDArrayFactory::create<double>('c', {4,4,3}, {-21.5, -20.5, -19.5,  -15.5, -14.5, -13.5,  -9.5,  -8.5,  -7.5,  -3.5,  -2.5,  -1.5, | ||||||
|           2.5,   3.5,   4.5,    8.5,   9.5,  10.5,  14.5,  15.5,  16.5,  20.5,  21.5,  22.5, |                                      2.5,   3.5,   4.5,    8.5,   9.5,  10.5,  14.5,  15.5,  16.5,  20.5,  21.5,  22.5, | ||||||
|          26.5,  27.5,  28.5,   32.5,  33.5,  34.5,  38.5,  39.5,  40.5,  44.5,  45.5,  46.5, |                                     26.5,  27.5,  28.5,   32.5,  33.5,  34.5,  38.5,  39.5,  40.5,  44.5,  45.5,  46.5, | ||||||
|          50.5,  51.5,  52.5,   56.5,  57.5,  58.5,  62.5,  63.5,  64.5,  68.5,  69.5,  70.5 |                                     50.5,  51.5,  52.5,   56.5,  57.5,  58.5,  62.5,  63.5,  64.5,  68.5,  69.5,  70.5}); | ||||||
|     }); | 
 | ||||||
|  | 
 | ||||||
|     x.linspace(1.); |     x.linspace(1.); | ||||||
|     nd4j::ops::adjust_contrast op; |     nd4j::ops::adjust_contrast op; | ||||||
|     auto result = op.execute({&x}, {2.}, {}, {}); |     auto result = op.execute({&x, &factor}, {}, {}, {}); | ||||||
|     ASSERT_EQ(Status::OK(), result->status()); |     ASSERT_EQ(Status::OK(), result->status()); | ||||||
|     auto out = result->at(0); |     auto out = result->at(0); | ||||||
| //    out->printIndexedBuffer("Adjusted Constrast");
 | 
 | ||||||
|     ASSERT_TRUE(e.equalsTo(out)); |     ASSERT_TRUE(e.equalsTo(out)); | ||||||
|     delete result; |     delete result; | ||||||
| } | } | ||||||
|  | |||||||
| @ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) { | |||||||
|     delete results; |     delete results; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | ///////////////////////////////////////////////////////////////////
 | ||||||
|  | TEST_F(DeclarableOpsTests3, betainc_test11) { | ||||||
|  | 
 | ||||||
|  |     NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32); | ||||||
|  |     NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::betainc op; | ||||||
|  |     auto results = op.execute({&a, &b, &x}, {}, {}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto *output = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(expected.isSameShape(output)); | ||||||
|  |     ASSERT_TRUE(expected.equalsTo(output)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| ///////////////////////////////////////////////////////////////////
 | ///////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests3, zeta_test1) { | TEST_F(DeclarableOpsTests3, zeta_test1) { | ||||||
| 
 | 
 | ||||||
| @ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) { | |||||||
|     x.linspace(10.); |     x.linspace(10.); | ||||||
| 
 | 
 | ||||||
|     auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); |     auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07}); | ||||||
|  |     nd4j::ops::polygamma op; | ||||||
|  |     auto results = op.execute({&n, &x}, {}, {}); | ||||||
| 
 | 
 | ||||||
|     //ASSERT_FALSE(true);
 |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto output = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(expected.isSameShape(output)); | ||||||
|  |     ASSERT_TRUE(expected.equalsTo(output)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(DeclarableOpsTests3, polygamma_test4) { | ||||||
|  | 
 | ||||||
|  |     NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE); | ||||||
|  |     NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE); | ||||||
|  | 
 | ||||||
|  |     NDArray expected('c', {3,4}, {/*std::numeric_limits<double>::quiet_NaN()*/-1.031918,  -7.021327e-01,  1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02, | ||||||
|  |                                 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08,  6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE); | ||||||
| 
 | 
 | ||||||
|     nd4j::ops::polygamma op; |     nd4j::ops::polygamma op; | ||||||
|     auto results = op.execute({&n, &x}, {}, {}); |     auto results = op.execute({&n, &x}, {}, {}); | ||||||
| @ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) { | |||||||
|     delete results; |     delete results; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(DeclarableOpsTests3, digamma_1) { | ||||||
|  | 
 | ||||||
|  |     NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE); | ||||||
|  | 
 | ||||||
|  |     NDArray expected('c', {18}, {std::numeric_limits<double>::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331, | ||||||
|  |                                  std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE); | ||||||
|  | 
 | ||||||
|  |     nd4j::ops::digamma op; | ||||||
|  |     auto results = op.execute({&x}, {}, {}); | ||||||
|  | 
 | ||||||
|  |     ASSERT_EQ(ND4J_STATUS_OK, results->status()); | ||||||
|  | 
 | ||||||
|  |     auto output = results->at(0); | ||||||
|  | 
 | ||||||
|  |     ASSERT_TRUE(expected.isSameShape(output)); | ||||||
|  |     ASSERT_TRUE(expected.equalsTo(output)); | ||||||
|  | 
 | ||||||
|  |     delete results; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| ///////////////////////////////////////////////////////////////////
 | ///////////////////////////////////////////////////////////////////
 | ||||||
| TEST_F(DeclarableOpsTests3, svd_test1) { | TEST_F(DeclarableOpsTests3, svd_test1) { | ||||||
| 
 | 
 | ||||||
|  | |||||||