// File: cavis/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp
//
// 2438 lines
// 202 KiB
// C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_CONVOLUTIONTESTS1_H
#define LIBND4J_CONVOLUTIONTESTS1_H
#include "testlayers.h"
#include <memory>
#include <NDArray.h>
#include <Context.h>
#include <Node.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <PointersManager.h>
#ifdef HAVE_MKLDNN
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
#endif
using namespace nd4j;
using namespace nd4j::graph;
// Fixture for the non-typed (single data type) convolution tests, e.g. TEST_F(ConvolutionTests1, ...).
class ConvolutionTests1 : public testing::Test {};

// Fixture for typed tests: every TYPED_TEST(TypedConvolutionTests1, ...) is instantiated
// once for each type listed in TestingTypes.
template <typename T>
class TypedConvolutionTests1 : public testing::Test {};

using TestingTypes = ::testing::Types<double, float>;
TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes);
//////////////////////////////////////////////////////////////////////
// conv2d through the Context/VariableSpace execution path: NCHW input, 2x2 kernel, unit
// stride, no explicit padding, SAME mode. Inputs and weights are filled with 1..N, and the
// weights are permuted from {oC, iC, kH, kW} to {kH, kW, iC, oC} before execution.
// Heap objects are owned by unique_ptr so a failing ASSERT_* (which returns early) does not leak.
TYPED_TEST(TypedConvolutionTests1, conv2d_1) {
    int bS=1, iH=5,iW=4, iC=2,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    // SAME padding with unit stride keeps the spatial size unchanged.
    const int oH = iH, oW = iW;

    TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0};
    // Raw shape info for the expected {1, 3, 5, 4} 'c'-order array; the data-type field is
    // overwritten below with the actual input dtype.
    Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};

    // input/weights are handed over to variableSpace below, which is responsible for them
    // (the original test never deleted them explicitly).
    auto input = NDArrayFactory::create_<TypeParam>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create_<TypeParam>('c', {oC, iC, kH, kW});
    for (int e = 0; e < input->lengthOf(); e++)
        input->p(e, e + 1);
    for (int e = 0; e < weights->lengthOf(); e++)
        weights->p(e, e + 1);
    // {oC, iC, kH, kW} -> {kH, kW, iC, oC}
    weights->permutei({2,3,1,0});

    ArrayOptions::setDataType(_expS, input->dataType());

    // Declaration order fixes destruction order: block, then variableSpace, then exp.
    std::unique_ptr<NDArray> exp(new NDArray(_expB, _expS));
    std::unique_ptr<VariableSpace> variableSpace(new VariableSpace());
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, weights);
    std::unique_ptr<Context> block(new Context(1, variableSpace.get(), false)); // not-in-place
    block->fillInputs({-1, -2});

    // kH,kW kernel (2x2)
    block->getIArguments()->push_back(kH);
    block->getIArguments()->push_back(kW);
    // sH,sW stride (1,1)
    block->getIArguments()->push_back(sH);
    block->getIArguments()->push_back(sW);
    // pH,pW padding (0,0)
    block->getIArguments()->push_back(pH);
    block->getIArguments()->push_back(pW);
    // dH,dW dilation (1,1)
    block->getIArguments()->push_back(dH);
    block->getIArguments()->push_back(dW);
    // padding mode: 1 = SAME
    block->getIArguments()->push_back(1);
    // data format: 0 = NCHW (1 would be NHWC)
    block->getIArguments()->push_back(0);

    nd4j::ops::conv2d op;
    Nd4jStatus status = op.execute(block.get());
    ASSERT_EQ(ND4J_STATUS_OK, status);

    // The op writes its result into variable id 1 of the variable space.
    auto res = variableSpace->getVariable(1)->getNDArray();

    // checking output shape: {bS, oC, oH, oW}
    ASSERT_EQ(bS, res->sizeAt(0));
    ASSERT_EQ(oC, res->sizeAt(1));
    ASSERT_EQ(oH, res->sizeAt(2));
    ASSERT_EQ(oW, res->sizeAt(3));
    ASSERT_TRUE(res->isSameShape(exp.get()));

    // final check
    ASSERT_TRUE(res->equalsTo(exp.get()));
}
//////////////////////////////////////////////////////////////////////
// 1x1 convolution, unit stride, no padding, VALID mode, NCHW layout.
// The input row is 1,2,3,4 and every weight is 2, so each of the 4 output channels
// holds the input doubled: 2,4,6,8.
TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
    const int kH = 1, kW = 1, sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1;
    const int paddingMode = 0;   // 1-SAME, 0-VALID
    const int dataFormat  = 0;   // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
    input.linspace(1);
    weights.assign(2.0);

    auto expected = NDArrayFactory::create<TypeParam>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});

    nd4j::ops::conv2d conv;
    auto results = conv.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto actual = results->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// conv2d, NHWC, SAME padding, 3x2 kernel, unit stride, no bias.
// Input is constant 2 and weights are 0.1, 0.2, ...; output keeps the 4x3 spatial size.
TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
    int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});

    // No bias is passed to the op, so these are the raw convolution sums.
    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
                                                                            170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
                                                                            152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
                                                                            170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching the result set so a failure is reported cleanly.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// conv2d, NHWC, VALID padding, 3x2 kernel, unit stride, no bias.
// Input is constant 2 and weights are 0.1, 0.2, ...; every output position sees the full
// kernel, so each channel is constant: 170.4, 175.2, 180.
TYPED_TEST(TypedConvolutionTests1, conv2d_4) {
    int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});

    // No bias is passed to the op, so these are the raw convolution sums.
    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f});

    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching the result set so a failure is reported cleanly.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// conv2d, NCHW, VALID padding, 3x2 kernel, unit stride, WITH bias {1,2,3}.
// Weights are generated as {oC, iC, kH, kW} and permuted to {kH, kW, iC, oC} before use.
TYPED_TEST(TypedConvolutionTests1, conv2d_5) {
    int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW});
    auto bias    = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f});

    input = 2.;
    weights.linspace(0.1, 0.1);
    // {oC, iC, kH, kW} -> {kH, kW, iC, oC}
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d op;
    auto results = op.execute({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching the result set so a failure is reported cleanly.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
// Smoke test: conv2d only has to report success; its output is not inspected.
// kH,kW are passed as -1 (NOTE(review): presumably "derive the kernel size from the
// weights array" — confirm in the conv2d op implementation).
TYPED_TEST(TypedConvolutionTests1, conv2d_6) {
    const int paddingMode = 1;   // 1-SAME, 0-VALID
    const int dataFormat  = 1;   // 1-NHWC, 0-NCHW

    auto in = NDArrayFactory::create<TypeParam>('c', {54, 1, 12, 12});
    auto w  = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2});

    nd4j::ops::conv2d conv;
    auto res = conv.execute({&in, &w}, {}, {-1,-1, 1,1, 0,0, 1,1, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), res->status());

    delete res;
}
// avgpool2d forward pass on a {4, 10, 10, 3} input, compared against a hard-coded {4, 4, 4, 3}
// result. The "_TF" suffix suggests the reference values were produced with TensorFlow
// (NOTE(review): not confirmed by this file).
TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) {
// Input fixture: 4*10*10*3 = 1200 values in 'c' order; the trailing dimension (3) appears to be
// channels (NHWC), consistent with the {4, 4, 4, 3} expected output below.
auto input = NDArrayFactory::create<TypeParam>('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 
5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 
0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, 
-1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 
2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 
0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, 
-2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, 
-4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, 
-7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f});
// Expected pooled result: 4*4*4*3 = 192 values in 'c' order. A 3x3 window with stride 3 over a
// 10x10 spatial grid yields 4x4 only if the border is padded (ceil(10/3) == 4), which is
// consistent with the SAME-mode reading of the arguments below.
auto exp = NDArrayFactory::create<TypeParam>('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, 
-0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f});
nd4j::ops::avgpool2d op;
// iArgs: kH,kW = 3,3; sH,sW = 3,3; pH,pW = 0,0; dH,dW = 1,1; then 1, 0, 1.
// NOTE(review): by analogy with the conv2d argument order used in this file, the trailing
// values are presumably paddingMode=1 (SAME), an extra pooling parameter = 0, and
// dataFormat=1 (NHWC) — confirm against the avgpool2d op declaration.
auto result = op.execute({&input}, {}, {3,3, 3,3, 0,0, 1,1,1, 0,1}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printIndexedBuffer("z");
// exp.printIndexedBuffer("e");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
// conv2d on a large 256x256 NCHW input with SAME padding and unit stride.
// Previously the test fetched the output into an unused local and only checked the status.
// It now also verifies the output shape, which SAME padding with unit stride fixes at
// {bS, oC, iH, iW}.
TYPED_TEST(TypedConvolutionTests1, conv2d_7) {
    int bS=1, iH=256,iW=256, iC=1,oC=1, kH=4,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=256,oW=256;               // SAME + unit stride keeps the spatial size
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    input = 5.;
    weights = 3.;

    nd4j::ops::conv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching the result set so a failure is reported cleanly.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    // NCHW output: {bS, oC, oH, oW}
    ASSERT_EQ(bS, output->sizeAt(0));
    ASSERT_EQ(oC, output->sizeAt(1));
    ASSERT_EQ(oH, output->sizeAt(2));
    ASSERT_EQ(oW, output->sizeAt(3));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, sconv2d_1) {
    // Depthwise-only sconv2d (no pointwise weights): [2, 2, 10, 10] input filled with 1, 2, 3, ...,
    // weights [3, 2, 5, 5] (also 1, 2, 3, ...) permuted to [5, 5, 2, 3], stride 1, no padding,
    // VALID mode -> expected [2, 6, 6, 6] float output.
    float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 
461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 
746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,};
    Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99};
    NDArray exp(_expB, _expS);

    // geometry: batch / channels / spatial sizes, then kernel, stride and padding
    const int B  = 2,  iC = 2,  oC = 3;
    const int iY = 10, iX = 10;
    const int kY = 5,  kX = 5;
    const int sY = 1,  sX = 1;
    const int pY = 0,  pX = 0;

    // writes 1, 2, 3, ... into the array in buffer order
    auto fillSequential = [](NDArray* arr) {
        for (int idx = 0; idx < arr->lengthOf(); idx++)
            arr->p(idx, idx + 1);
    };

    auto input = NDArrayFactory::create_<float>('c', {B, iC, iY, iX});
    fillSequential(input);

    auto weights = NDArrayFactory::create_<float>('c', {oC, iC, kY, kX});
    fillSequential(weights);
    weights->permutei({2,3,1,0});

    // input/weights are never deleted in this test; presumably the VariableSpace owns them after putVariable
    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, weights);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1, -2});

    // integer args in order: kernel (kY, kX), stride (sY, sX), padding (pY, pX),
    // dilation (1, 1), and 0 for NOT same mode
    for (const int iArg : {kY, kX, sY, sX, pY, pX, 1, 1, 0})
        block->getIArguments()->push_back(iArg);

    nd4j::ops::sconv2d op;
    const Nd4jStatus status = op.execute(block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto output = variableSpace->getVariable(1)->getNDArray();
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete block;
    delete variableSpace;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, sconv2d_2) {
    // Separable conv2d: 5x5 depthwise weights [5, 3, 5, 5] -> permuted [5, 5, 3, 5] followed by
    // 1x1 pointwise weights [10, 15, 1, 1] -> permuted [1, 1, 15, 10], on a [2, 3, 10, 10] input.
    // All inputs are linspace(1) / 100. Expected output is [2, 10, 6, 6].
    TypeParam _expBFF[] = {108.9405008, 109.5920008, 110.2435008, 110.8950008, 111.5465008, 112.1980008, 115.4555008, 116.1070008, 116.7585008, 117.410000, 118.061500, 118.7130009, 121.9705009, 122.6220009, 123.2735009, 123.9250009, 124.5765009, 125.2280009, 128.4855009, 129.1370009, 129.7885009, 130.4400009, 131.09150, 131.74300, 135.0005010, 135.6520010, 136.3035010, 136.9550010, 137.6065010, 138.2580010, 141.5155010, 142.1670010, 142.8185010, 143.4700010, 144.1215010, 144.7730010, 248.9617514, 250.670751, 252.3797515, 254.0887515, 255.7977515, 257.5067515, 266.0517515, 267.7607515, 269.469751, 271.1787516, 272.8877516, 274.5967516, 283.1417516, 284.8507516, 286.5597516, 288.268751, 289.9777517, 291.6867517, 300.2317517, 301.9407517, 303.6497517, 305.3587517, 307.067751, 308.7767518, 317.3217518, 319.0307518, 320.7397518, 322.4487518, 324.157751, 325.866751, 334.4117519, 336.1207519, 337.8297519, 339.5387519, 341.2477519, 342.95675, 388.9829964, 391.7494964, 394.5159964, 397.2824964, 400.048996, 402.8154963, 416.647996, 419.4144962, 422.1809962, 424.9474962, 427.7139962, 430.4804962, 444.3129961, 447.0794961, 449.8459961, 452.6124960, 455.3789960, 458.1454960, 471.9779959, 474.7444959, 477.5109959, 480.2774959, 483.0439959, 485.8104958, 499.6429958, 502.4094957, 505.1759957, 507.9424957, 510.7089957, 513.4754957, 527.3079956, 530.0744956, 532.8409956, 535.607495, 538.3739955, 541.1404955, 529.0042487, 532.8282487, 536.6522487, 540.4762487, 544.3002487, 548.1242487, 567.2442487, 571.068248, 574.892248, 578.716248, 582.540248, 586.3642486, 605.4842486, 609.3082486, 613.1322486, 616.9562486, 620.7802486, 624.6042486, 643.7242486, 647.5482486, 651.3722486, 655.1962486, 659.0202486, 662.8442486, 681.9642486, 685.7882486, 689.6122486, 693.4362486, 697.2602486, 701.0842486, 720.2042486, 724.0282486, 727.852248, 731.676248, 735.500248, 739.324248, 669.0255044, 673.9070044, 678.7885044, 683.6700044, 688.5515044, 693.4330044, 717.8405044, 722.7220044, 727.6035044, 732.4850044, 
737.3665044, 742.2480044, 766.6555043, 771.5370043, 776.4185043, 781.3000043, 786.1815043, 791.0630043, 815.4705043, 820.3520043, 825.2335043, 830.1150043, 834.9965043, 839.8780043, 864.2855042, 869.1670042, 874.0485042, 878.9300042, 883.8115042, 888.6930042, 913.1005042, 917.9820042, 922.8635042, 927.7450042, 932.6265042, 937.5080042, 809.0467424, 814.9857424, 820.9247424, 826.8637423, 832.8027423, 838.7417423, 868.4367421, 874.3757421, 880.3147420, 886.2537420, 892.1927420, 898.13174, 927.8267418, 933.7657418, 939.7047417, 945.6437417, 951.5827417, 957.5217416, 987.2167415, 993.155741, 999.0947414, 1005.0337414, 1010.972741, 1016.9117413, 1046.6067412, 1052.5457411, 1058.4847411, 1064.4237411, 1070.3627410, 1076.3017410, 1105.996740, 1111.9357408, 1117.8747408, 1123.8137408, 1129.7527407, 1135.6917407, 949.0679815, 956.0644814, 963.060981, 970.0574813, 977.0539812, 984.0504811, 1019.0329807, 1026.0294807, 1033.0259806, 1040.0224805, 1047.0189804, 1054.0154804, 1088.9979800, 1095.9944799, 1102.9909798, 1109.987479, 1116.9839797, 1123.9804796, 1158.9629792, 1165.9594791, 1172.9559791, 1179.9524790, 1186.9489789, 1193.9454788, 1228.9279785, 1235.9244784, 1242.9209783, 1249.9174782, 1256.913978, 1263.9104781, 1298.8929777, 1305.8894776, 1312.8859775, 1319.8824775, 1326.8789774, 1333.8754773, 1089.0892560, 1097.1432561, 1105.1972562, 1113.251256, 1121.3052563, 1129.3592564, 1169.6292568, 1177.6832568, 1185.7372569, 1193.7912570, 1201.845257, 1209.8992571, 1250.1692575, 1258.2232576, 1266.2772576, 1274.3312577, 1282.3852578, 1290.4392579, 1330.7092582, 1338.7632583, 1346.8172584, 1354.8712584, 1362.9252585, 1370.9792586, 1411.24925, 1419.3032590, 1427.3572591, 1435.4112592, 1443.465259, 1451.5192593, 1491.7892597, 1499.8432598, 1507.8972598, 1515.9512599, 1524.0052600, 1532.059260, 1229.1105073, 1238.2220073, 1247.3335073, 1256.4450073, 1265.5565073, 1274.668007, 1320.2255074, 1329.3370074, 1338.4485074, 1347.5600075, 1356.6715075, 1365.7830075, 1411.340507, 
1420.4520076, 1429.5635076, 1438.6750076, 1447.7865076, 1456.8980076, 1502.4555077, 1511.5670077, 1520.6785077, 1529.7900077, 1538.9015077, 1548.013007, 1593.5705078, 1602.6820078, 1611.793507, 1620.9050079, 1630.0165079, 1639.1280079, 1684.6855080, 1693.7970080, 1702.9085080, 1712.0200080, 1721.1315080, 1730.2430080, 1369.1317613, 1379.3007614, 1389.4697614, 1399.6387615, 1409.8077615, 1419.976761, 1470.8217618, 1480.9907618, 1491.159761, 1501.3287619, 1511.4977619, 1521.6667620, 1572.5117622, 1582.6807622, 1592.8497623, 1603.0187623, 1613.1877624, 1623.3567624, 1674.2017626, 1684.3707627, 1694.5397627, 1704.7087628, 1714.8777628, 1725.046762, 1775.8917631, 1786.0607631, 1796.229763, 1806.3987632, 1816.5677632, 1826.7367633, 1877.5817635, 1887.7507635, 1897.9197636, 1908.0887636, 1918.2577637, 1928.4267637, 304.3905022, 305.0420022, 305.6935022, 306.3450022, 306.9965022, 307.6480022, 310.9055022, 311.5570022, 312.208502, 312.860002, 313.5115023, 314.1630023, 317.4205023, 318.0720023, 318.7235023, 319.3750023, 320.0265023, 320.6780023, 323.9355023, 324.5870023, 325.2385023, 325.8900023, 326.541502, 327.193002, 330.4505024, 331.1020024, 331.7535024, 332.4050024, 333.0565024, 333.7080024, 336.9655024, 337.6170024, 338.2685024, 338.9200024, 339.5715024, 340.223002, 761.6617542, 763.3707542, 765.0797542, 766.7887542, 768.4977542, 770.206754, 778.7517543, 780.4607543, 782.1697543, 783.8787543, 785.5877543, 787.2967543, 795.8417544, 797.5507544, 799.2597544, 800.9687544, 802.6777544, 804.3867544, 812.9317545, 814.6407545, 816.3497545, 818.0587545, 819.7677545, 821.4767545, 830.0217546, 831.7307546, 833.4397546, 835.1487546, 836.8577546, 838.5667546, 847.1117547, 848.8207547, 850.5297547, 852.2387547, 853.9477547, 855.6567547, 1218.9329915, 1221.6994915, 1224.4659915, 1227.232491, 1229.9989914, 1232.7654914, 1246.5979913, 1249.3644913, 1252.1309913, 1254.8974913, 1257.6639913, 1260.430491, 1274.2629912, 1277.029491, 1279.7959911, 1282.5624911, 1285.3289911, 1288.0954911, 
1301.9279910, 1304.6944910, 1307.4609910, 1310.22749, 1312.9939909, 1315.7604909, 1329.5929908, 1332.3594908, 1335.1259908, 1337.8924908, 1340.6589908, 1343.4254908, 1357.2579907, 1360.0244907, 1362.7909906, 1365.5574906, 1368.3239906, 1371.0904906, 1676.2042479, 1680.0282479, 1683.8522479, 1687.6762479, 1691.5002479, 1695.3242479, 1714.4442479, 1718.2682479, 1722.0922479, 1725.9162479, 1729.7402479, 1733.5642479, 1752.6842479, 1756.5082479, 1760.3322479, 1764.1562479, 1767.9802479, 1771.8042479, 1790.9242479, 1794.7482479, 1798.5722479, 1802.3962479, 1806.2202479, 1810.044247, 1829.1642478, 1832.9882478, 1836.8122478, 1840.6362478, 1844.4602478, 1848.2842478, 1867.4042478, 1871.2282478, 1875.0522478, 1878.8762478, 1882.7002478, 1886.5242478, 2133.4755029, 2138.3570029, 2143.2385029, 2148.1200029, 2153.0015029, 2157.8830029, 2182.2905028, 2187.1720028, 2192.0535028, 2196.9350028, 2201.8165028, 2206.6980028, 2231.1055028, 2235.9870028, 2240.8685028, 2245.7500028, 2250.6315028, 2255.5130028, 2279.9205027, 2284.8020027, 2289.6835027, 2294.5650027, 2299.4465027, 2304.3280027, 2328.7355027, 2333.6170027, 2338.4985027, 2343.3800027, 2348.2615027, 2353.1430027, 2377.5505026, 2382.4320026, 2387.3135026, 2392.1950026, 2397.0765026, 2401.9580026, 2590.7467330, 2596.6857330, 2602.6247329, 2608.5637329, 2614.5027329, 2620.441732, 2650.1367327, 2656.0757327, 2662.0147326, 2667.9537326, 2673.8927326, 2679.8317325, 2709.5267324, 2715.465732, 2721.4047323, 2727.3437323, 2733.282732, 2739.2217322, 2768.9167321, 2774.8557320, 2780.7947320, 2786.7337320, 2792.6727319, 2798.6117319, 2828.306731, 2834.2457317, 2840.1847317, 2846.1237317, 2852.0627316, 2858.0017316, 2887.6967314, 2893.6357314, 2899.5747314, 2905.5137313, 2911.4527313, 2917.3917313, 3048.0179587, 3055.0144586, 3062.0109585, 3069.0074584, 3076.0039584, 3083.0004583, 3117.9829579, 3124.9794578, 3131.9759578, 3138.9724577, 3145.9689576, 3152.9654575, 3187.947957, 3194.9444571, 3201.9409570, 3208.9374569, 3215.933956, 
3222.9304568, 3257.9129564, 3264.9094563, 3271.9059562, 3278.9024562, 3285.8989561, 3292.8954560, 3327.8779556, 3334.874455, 3341.8709555, 3348.8674554, 3355.8639553, 3362.860455, 3397.8429549, 3404.8394548, 3411.8359547, 3418.8324546, 3425.8289546, 3432.8254545, 3505.28927, 3513.3432780, 3521.3972781, 3529.4512782, 3537.5052782, 3545.5592783, 3585.8292787, 3593.8832788, 3601.9372788, 3609.9912789, 3618.0452790, 3626.099279, 3666.3692794, 3674.4232795, 3682.4772796, 3690.5312796, 3698.5852797, 3706.6392798, 3746.9092801, 3754.9632802, 3763.0172803, 3771.0712804, 3779.1252804, 3787.1792805, 3827.4492809, 3835.50328, 3843.5572810, 3851.6112811, 3859.6652812, 3867.7192812, 3907.9892816, 3916.0432817, 3924.097281, 3932.1512818, 3940.2052819, 3948.2592820, 3962.5605113, 3971.6720113, 3980.783511, 3989.8950114, 3999.0065114, 4008.1180114, 4053.6755115, 4062.7870115, 4071.8985115, 4081.0100115, 4090.1215115, 4099.2330115, 4144.7905116, 4153.9020116, 4163.0135116, 4172.1250116, 4181.236511, 4190.3480117, 4235.9055117, 4245.0170117, 4254.128511, 4263.2400118, 4272.3515118, 4281.4630118, 4327.0205119, 4336.1320119, 4345.2435119, 4354.3550119, 4363.4665119, 4372.5780119, 4418.1355120, 4427.2470120, 4436.3585120, 4445.4700120, 4454.581512, 4463.6930121, 4419.8317743, 4430.0007744, 4440.1697744, 4450.338774, 4460.5077745, 4470.6767745, 4521.521774, 4531.6907748, 4541.8597748, 4552.0287749, 4562.1977749, 4572.3667750, 4623.2117752, 4633.3807752, 4643.5497753, 4653.7187753, 4663.8877754, 4674.0567754, 4724.9017756, 4735.0707757, 4745.2397757, 4755.4087757, 4765.5777758, 4775.7467758, 4826.591776, 4836.7607761, 4846.9297761, 4857.0987762, 4867.2677762, 4877.4367763, 4928.2817765, 4938.4507765, 4948.6197766, 4958.7887766, 4968.957776, 4979.12677675};
    Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
    NDArray expFF(_expBFF, _expSFF);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 3, 10, 10});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {5, 3, 5, 5});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {10, 15, 1, 1});

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    input.applyScalar(scalar::Divide, 100.0);
    weightsD.applyScalar(scalar::Divide, 100.0);
    weightsP.applyScalar(scalar::Divide, 100.0);

    nd4j::ops::sconv2d op;
    // iArgs: kernel 5x5, stride 1x1, padding 0x0, dilation 1x1, 0 = VALID mode,
    // trailing 0 presumably the data-format flag (input is laid out as NCHW) -- TODO confirm
    auto resultFF = op.execute({&input, &weightsD, &weightsP}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}, {});
    // the op must succeed before its output can be inspected
    ASSERT_EQ(Status::OK(), resultFF->status());
    auto z = resultFF->at(0);

    ASSERT_TRUE(z->isSameShape(&expFF));
    // several expected values are truncated in the table above, hence the loose tolerance
    ASSERT_TRUE(z->equalsTo(&expFF, 1e-3));

    delete resultFF;
}
TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
    // Separable conv2d with bias, 1x1 depthwise weights [1, 3, 1, 1] -> [1, 1, 3, 1] and
    // 1x1 pointwise weights [2, 3, 1, 1] -> [1, 1, 3, 2], on a [3, 3, 8, 8] input.
    // The op is run twice: once writing into a caller-provided [3, 2, 8, 8] array and once
    // allocating its own result. Both runs must succeed and produce identical results.
    auto input    = NDArrayFactory::create<TypeParam>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {1, 3, 1, 1});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {2, 3, 1, 1});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {2});
    auto output   = NDArrayFactory::create<TypeParam>('c', {3, 2, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    bias.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    nd4j::ops::sconv2d op;
    // iArgs: kernel 1x1, stride 1x1, padding 0x0, dilation 1x1, 0 = VALID mode

    // run 1: results are written into the pre-allocated `output`
    Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
    ASSERT_EQ(ND4J_STATUS_OK, status);

    // run 2: the op allocates its own result
    auto result = op.execute({&input, &weightsD, &weightsP, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);

    // both execution paths compute the same thing from the same inputs
    ASSERT_TRUE(output.isSameShape(z));
    ASSERT_TRUE(output.equalsTo(z));

    delete result;
}
TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) {
    // deconv2d forward pass without bias. Input is [2, 3, 4, 4]; weights are [3, 3, 5, 5]
    // permuted to [5, 5, 3, 3]; stride 1, no padding, VALID mode. Expected output is
    // [2, 3, 8, 8] (4 + 5 - 1 = 8). Both input and weights are filled with linspace(1).
    Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    TypeParam _expB[] = {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 
122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0,};
    NDArray exp(_expB, _expS);

    auto input = NDArrayFactory::create_<TypeParam>('c', {2, 3, 4, 4});
    auto weights = NDArrayFactory::create_<TypeParam>('c', {3, 3, 5, 5});
    input->linspace(1);
    weights->linspace(1);
    weights->permutei({2,3,1,0});

    // input/weights are never deleted in this test; presumably the VariableSpace owns them after putVariable
    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, weights);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1, -2});
    // kernel
    block->getIArguments()->push_back(5);
    block->getIArguments()->push_back(5);
    // stride
    block->getIArguments()->push_back(1);
    block->getIArguments()->push_back(1);
    // padding
    block->getIArguments()->push_back(0);
    block->getIArguments()->push_back(0);
    // dilation
    block->getIArguments()->push_back(1);
    block->getIArguments()->push_back(1);
    // NOT same mode
    block->getIArguments()->push_back(0);
    // presumably the data-format flag (input is laid out as NCHW) -- TODO confirm
    block->getIArguments()->push_back(0);

    nd4j::ops::deconv2d op;
    Nd4jStatus status = op.execute(block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto output = variableSpace->getVariable(1)->getNDArray();
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    // Release the Context first, then the VariableSpace it was constructed with,
    // so the block never outlives the space it references.
    delete block;
    delete variableSpace;
}
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
    // conv2d backprop with bias. Input is [2, 1, 4, 4]; weights are [2, 1, 3, 3] permuted to
    // [3, 3, 1, 2]; epsilonNext is [2, 2, 4, 4]. Kernel is 3x3, stride 1, no padding, and the last
    // iArg is 1 (same mode, matching the 4x4 epsilonNext). Outputs are checked against
    // precomputed gradients for input (epsilon), weights and bias.
    TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0};
    Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expWGrad(_expWGradB, _expWGradS);
    // bring the expected weight gradient into the same permuted layout as the weights
    expWGrad.permutei({2,3,1,0});

    TypeParam _expBGradB[] = {784.0, 1296.0};
    Nd4jLong _expBGradS[] = {2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expBGrad(_expBGradB, _expBGradS);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto bias = NDArrayFactory::create<TypeParam>('c', {2, 1});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});

    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, input.getShapeInfo());

    input.linspace(1);
    weights.linspace(1);
    epsilonNext.linspace(1);
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
    // the op must succeed before any of its outputs are inspected
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_EQ(3, results->size());

    auto epsilon = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expWGrad.isSameShape(gradW));
    ASSERT_TRUE(expWGrad.equalsTo(gradW));

    ASSERT_TRUE(input.isSameShape(epsilon));
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    ASSERT_TRUE(expBGrad.isSameShape(gradB));
    ASSERT_TRUE(expBGrad.equalsTo(gradB));

    delete results;
}
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) {
    // conv2d backprop without bias. Geometry is the same as conv2D_BP_Bias_1: input [2, 1, 4, 4],
    // weights [2, 1, 3, 3] permuted to [3, 3, 1, 2], epsilonNext [2, 2, 4, 4], 3x3 kernel,
    // stride 1, no padding, and last iArg 1 (same mode). Only two gradients are expected:
    // epsilon and the weight gradient.
    TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0};
    Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expWGrad(_expWGradB, _expWGradS);
    // bring the expected weight gradient into the same permuted layout as the weights
    expWGrad.permutei({2,3,1,0});

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});

    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, input.getShapeInfo());

    input.linspace(1);
    weights.linspace(1);
    epsilonNext.linspace(1);
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});
    // the op must succeed before any of its outputs are inspected
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_EQ(2, results->size());

    auto epsilon = results->at(0);
    auto gradW = results->at(1);

    ASSERT_TRUE(expWGrad.isSameShape(gradW));
    ASSERT_TRUE(expWGrad.equalsTo(gradW));

    ASSERT_TRUE(input.isSameShape(epsilon));
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    delete results;
}
TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {
TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 
785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 
829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.};
Nd4jLong _expSFF[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
NDArray expFF(_expBFF, _expSFF);
TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 
4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 
9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 
5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 
15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 
26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062};
Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
NDArray exp2FF(_exp2BFF, _exp2SFF);
auto input = NDArrayFactory::create<TypeParam>('c', {2, 3, 10, 10});
auto weightsD = NDArrayFactory::create<TypeParam>('c', {2, 3, 5, 5});
auto weightsP = NDArrayFactory::create<TypeParam>('c', {10, 6, 1, 1});
input.linspace(1);
weightsD.linspace(1);
weightsP.linspace(1);
weightsD.permutei({2,3,1,0});
weightsP.permutei({2,3,1,0});
weightsP.applyScalar(scalar::Divide, 10000.0);
nd4j::ops::sconv2d op;
auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});
auto z = resultFF->at(0);
ASSERT_TRUE(z->isSameShape(&expFF));
ASSERT_TRUE(z->equalsTo(&expFF, 1));
nd4j::ops::conv2d op2d;
// weightsP.printShapeInfo();
auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});
auto z2d = result2D->at(0);
ASSERT_TRUE(z2d->isSameShape(&exp2FF));
ASSERT_TRUE(z2d->equalsTo(&exp2FF));
delete resultFF;
delete result2D;
}
TEST_F(ConvolutionTests1, Test_im2col_col2im_1) {
    // Cross-checks the im2col / col2im custom ops against the legacy Im2col / Col2Im transforms
    // for a {2, 3, 28, 28} input, 5x5 kernel, unit stride and dilation, SAME padding.
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;
    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;

    x.syncToDevice();
    //ASSERT_TRUE(x.isActualOnDeviceSide());
    ASSERT_TRUE(x.isActualOnHostSide());

    // Compute the output size; in SAME mode calcPadding2D then derives pY/pX from it.
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    // Reference im2col produced by the legacy transform.
    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, kY, kX, oY, oX});
    ExtraArguments args({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args);

    // Same computation through the custom op; make sure it succeeded before touching its output.
    nd4j::ops::im2col op;
    auto result2col = op.execute({&x}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2col->status());
    auto im2col1 = result2col->at(0);

    ASSERT_TRUE(im2col1->isSameShape(&im2col0));
    ASSERT_TRUE(im2col1->equalsTo(&im2col0));

    // Reference col2im produced by the legacy transform.
    ExtraArguments args2({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2);

    // Same computation through the custom op, fed with the custom op's im2col output.
    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());
    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2col;
    delete result2im;
}
TEST_F(ConvolutionTests1, Test_im2col_col2im_2) {
    // Like Test_im2col_col2im_1, but the legacy Im2col transform writes into a permuted
    // (non-contiguous) view: storage {2, C, oY, oX, kY, kX} viewed as {2, C, kY, kX, oY, oX}.
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;
    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;

    // Compute the output size; in SAME mode calcPadding2D then derives pY/pX from it.
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col0.permutei({0, 1, 4, 5, 2, 3});

    ExtraArguments args2col({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args2col);

    // Custom op result; check it succeeded before reading its output.
    nd4j::ops::im2col op;
    auto result2col = op.execute({&x}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2col->status());
    auto im2col1 = result2col->at(0);

    ASSERT_TRUE(im2col1->isSameShape(&im2col0));
    ASSERT_TRUE(im2col1->equalsTo(&im2col0));

    ExtraArguments args2im({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2im);

    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());
    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2col;
    delete result2im;
}
TEST_F(ConvolutionTests1, Test_im2col_col2im_3) {
    // Like Test_im2col_col2im_2, but the im2col custom op writes into a caller-provided,
    // permuted output array instead of allocating its own result set.
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;
    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;

    // Compute the output size; in SAME mode calcPadding2D then derives pY/pX from it.
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    // Both the reference and the op output are permuted views over {2, C, oY, oX, kY, kX} storage.
    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col0.permutei({0, 1, 4, 5, 2, 3});

    auto im2col1 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col1.permutei({0, 1, 4, 5, 2, 3});

    ExtraArguments args2col({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args2col);

    nd4j::ops::im2col op;
    auto status = op.execute({&x}, {&im2col1}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0}, {});
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(im2col1.isSameShape(&im2col0));
    ASSERT_TRUE(im2col1.equalsTo(&im2col0));

    ExtraArguments args2im({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2im);

    // col2im returns a result set; verify it succeeded before reading its output.
    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({&im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());
    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2im;
}
TEST_F(ConvolutionTests1, TestDeconv_bp_1) {
    // deconv2d_bp with 1x1 kernels.
    // Shapes: input {3, 3, 4, 4}, weights {3, 2, 1, 1} (permuted before use),
    // bias {1, 2}, incoming gradient {3, 2, 4, 4}.
    // The three outputs are compared against precomputed gradients.

    // Expected gradient w.r.t. the input (same shape as input).
    double expEpsData[] = { 35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 272.f, 519.f, 526.f, 533.f, 540.f, 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f};
    std::shared_ptr<DataBuffer> epsBuffer = std::make_shared<DataBuffer>(expEpsData, sizeof(expEpsData), nd4j::DataType::DOUBLE, false);
    NDArray expEpsilon(epsBuffer, 'c', {3, 3, 4, 4});

    // Expected weight gradient, permuted the same way as the weights fed to the op.
    double expGradWData[] = { 160008.f, 203400.f, 191112.f, 246792.f, 222216.f, 290184.f};
    std::shared_ptr<DataBuffer> gradWBuffer = std::make_shared<DataBuffer>(expGradWData, sizeof(expGradWData), nd4j::DataType::DOUBLE, false);
    NDArray expGradW(gradWBuffer, 'c', {3, 2, 1, 1});
    expGradW.permutei({2,3,1,0});

    // Expected bias gradient.
    double expGradBData[] = {1944.f, 2712.f};
    std::shared_ptr<DataBuffer> gradBBuffer = std::make_shared<DataBuffer>(expGradBData, sizeof(expGradBData), nd4j::DataType::DOUBLE, false);
    NDArray expGradB(gradBBuffer, 'c', {1, 2});

    auto input   = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
    auto bias    = NDArrayFactory::create<double>('c', {1, 2});
    auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
    auto epsilon = NDArrayFactory::create<double>('c', {3, 2, 4, 4});

    input.linspace(1);
    weights.linspace(1);
    bias.linspace(1);
    epsilon.linspace(1);

    weights.permutei({2,3,1,0});

    nd4j::ops::deconv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &epsilon}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto gradI = results->at(0);
    ASSERT_TRUE(expEpsilon.isSameShape(gradI));
    ASSERT_TRUE(expEpsilon.equalsTo(gradI));

    auto gradW = results->at(1);
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    auto gradB = results->at(2);
    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
// NOTE(review): placeholder test. The whole body is commented out, so this test always
// passes without exercising deconv2d_bp. The disabled code below uses an older API:
// untyped `auto input('c', {...})` declarations and a templated `deconv2d_bp<double>` op.
// TestDeconv_bp_1 uses the non-templated `nd4j::ops::deconv2d_bp` instead.
// The code must be ported before it can be re-enabled.
// The shape notes list the bias as [1, 2], but the disabled code creates it as {2};
// TODO confirm which is intended.
/*
Input shape:
[3, 3, 14, 14]
Output shape:
[3, 2, 15, 15]
Weights shape:
[3, 2, 2, 2]
Bias shape:
[1, 2]
weight shape:
[3, 2, 2, 2]
weight grad shape:
[3, 2, 2, 2]
bias grad shape:
[2]
input epsilon shape:
[3, 2, 15, 15]
output epsilon shape:
[3, 3, 14, 14]
*/
/*
auto input('c', {3, 3, 14, 14});
auto bias('c', {2});
auto weights('c',{3, 2, 2, 2});
auto epsilon('c', {3, 2, 15, 15});
input.linspace(1);
weights.linspace(1);
bias.linspace(1);
epsilon.linspace(1);
nd4j::ops::deconv2d_bp<double> op;
auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
delete result;*/
}
TEST_F(ConvolutionTests1, TestDeconv_ff_2) {
    // Forward deconv2d with 1x1 kernels and a bias.
    // Input {3, 3, 4, 4} produces output {3, 2, 4, 4}.
    NDArray expected('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.});

    auto input   = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
    auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
    auto bias    = NDArrayFactory::create<double>('c', {2});

    input.linspace(1);
    weights.linspace(1);
    bias.linspace(1);

    // {3, 2, 1, 1} -> {1, 1, 2, 3} view before handing the weights to the op.
    weights.permutei({2,3,1,0});

    nd4j::ops::deconv2d op;
    auto results = op.execute({&input, &weights, &bias}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto out = results->at(0);
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete results;
}
TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
    // conv1d forward pass, then conv1d_bp on the same configuration.
    // Input {2, 2, 6}, weights {2, 2, 3}, bias {3}.
    // The forward output is {2, 3, 5}; backprop uses an incoming gradient of the same shape.
    auto input   = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
    auto bias    = NDArrayFactory::create<TypeParam>('c', {3});

    auto ffExpected    = NDArrayFactory::create<TypeParam>('c', {2, 3, 5}, {59.0, 69.0, 79.0, 89.0, 99.0, 132.0, 158.0, 184.0, 210.0, 236.0, 205.0, 247.0, 289.0, 331.0, 373.0, 179.0, 189.0, 199.0, 209.0, 219.0, 444.0, 470.0, 496.0, 522.0, 548.0, 709.0, 751.0, 793.0, 835.0, 877.0});
    auto epsExpected   = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}, {130.0, 293.0, 326.0, 359.0, 392.0, 220.0, 166.0, 371.0, 416.0, 461.0, 506.0, 280.0, 355.0, 788.0, 821.0, 854.0, 887.0, 490.0, 481.0, 1046.0, 1091.0, 1136.0, 1181.0, 640.0});
    auto gradWExpected = NDArrayFactory::create<TypeParam>('c', {3, 2, 2}, {1415.0, 1520.0, 2045.0, 2150.0, 1865.0, 2020.0, 2795.0, 2950.0, 2315.0, 2520.0, 3545.0, 3750.0});
    auto gradBExpected = NDArrayFactory::create<TypeParam>('c', {3}, {105.0, 155.0, 205.0});

    // The expected weight gradient is listed as {3, 2, 2}; reverse its axes to {2, 2, 3}.
    gradWExpected.permutei({2,1,0});

    input.linspace(1);
    bias.linspace(1);

    // ---- forward ----
    nd4j::ops::conv1d ffOp;
    auto ffResults = ffOp.execute({&input, &weights, &bias}, {}, {2, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, ffResults->status());

    auto ffOut = ffResults->at(0);
    ASSERT_TRUE(ffExpected.isSameShape(ffOut));
    ASSERT_TRUE(ffExpected.equalsTo(ffOut));

    // ---- backward: incoming gradient shaped like the forward output ----
    nd4j::ops::conv1d_bp bpOp;
    auto gradOut = ffOut->dup();
    gradOut->linspace(1);

    auto bpResults = bpOp.execute({&input, &weights, &bias, gradOut}, {}, {2, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, bpResults->status());

    auto gradIn = bpResults->at(0);
    auto gradW  = bpResults->at(1);
    auto gradB  = bpResults->at(2);

    ASSERT_TRUE(epsExpected.isSameShape(gradIn));
    ASSERT_TRUE(gradWExpected.isSameShape(gradW));
    ASSERT_TRUE(gradBExpected.isSameShape(gradB));

    ASSERT_TRUE(epsExpected.equalsTo(gradIn));
    ASSERT_TRUE(gradWExpected.equalsTo(gradW));
    ASSERT_TRUE(gradBExpected.equalsTo(gradB));

    delete ffResults;
    delete bpResults;
    delete gradOut;
}
TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
    // conv1d without bias, using the same kernel/stride/padding IArgs as Test_Conv1D_ff_1
    // except that the 4th IArg is 1 instead of 0. In Test_Conv1D_ff_1, 4th IArg 0 yields
    // width 6 - 2 + 1 = 5 (VALID). With 1 (SAME) and stride 1, the output width should stay
    // equal to the input width (6).
    auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});

    // Only the shape is pinned here: {bS = 2, oC = 3, oW = 6}.
    auto expectedShape = NDArrayFactory::create<TypeParam>('c', {2, 3, 6});

    input.linspace(1);

    nd4j::ops::conv1d op;
    auto result = op.execute({&input, &weights}, {}, {2, 1, 0, 1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    // Previously the output was fetched but never checked; at least verify its shape.
    auto z = result->at(0);
    ASSERT_TRUE(expectedShape.isSameShape(z));

    delete result;
}
TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
    // dilation2d on a {2, 6, 6, 3} input with a {3, 2, 3} filter.
    // The first IArg is 1 here and 0 in Test_Dilation2D_2. The spatial output is 3x3 here
    // versus 1x2 there, consistent with 1 = SAME and 0 = VALID padding.
    // The two following groups are both {1, 2, 2, 1}, presumably strides and rates; TODO confirm order.
    auto input    = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
    auto weights  = NDArrayFactory::create<double>('c', {3, 2, 3});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 3, 3}, {77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222,});

    input.linspace(1);
    weights.linspace(1);

    nd4j::ops::dilation2d op;
    auto results = op.execute({&input, &weights}, {}, {1, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), results->status());

    auto out = results->at(0);
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete results;
}
TEST_F(ConvolutionTests1, Test_Dilation2D_2) {
    // Same setup as Test_Dilation2D_1, but with the first IArg set to 0.
    // The expected output shrinks to {2, 1, 2, 3}, consistent with VALID padding.
    auto input    = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
    auto weights  = NDArrayFactory::create<double>('c', {3, 2, 3});
    auto expected = NDArrayFactory::create<double>('c', {2, 1, 2, 3}, {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213});

    input.linspace(1);
    weights.linspace(1);

    nd4j::ops::dilation2d op;
    auto results = op.execute({&input, &weights}, {}, {0, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), results->status());

    auto out = results->at(0);
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {
    // conv2d_bp in NHWC layout with SAME padding.
    // The input is constant 2, so each weight gradient is 2 * (sum of gradO over the output
    // positions that tap sees). The fully-covered tap (kh=1, kw=0) therefore equals 2 * gradB.
    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f,
                                                     3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f,
                                                     3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f,
                                                     11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f,
                                                    17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,
                                                    11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f});

    // Bias gradient = per-channel sum of gradO. gradO = linspace(0.01, 0.01) over 72 NHWC
    // elements, so channel c sums 0.01*(3i + c + 1) for i in [0, 24): 8.52 + 0.24*c.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC}, {8.52, 8.76, 9.});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Check the status before touching the outputs: a failed op may not populate them.
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
    // conv2d_bp in NHWC layout with VALID padding.
    // With a constant input of 2, every weight tap sees every output position, so each
    // weight gradient entry equals 2 * gradB for its output channel.
    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f,
                                                     0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f,
                                                     0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f,
                                                     3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
                                                    1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
                                                    1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f});

    // Bias gradient = per-channel sum of gradO. gradO = linspace(0.01, 0.01) over 24 NHWC
    // elements, so channel c sums 0.01*(3i + c + 1) for i in [0, 8): 0.92 + 0.08*c.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC}, {0.92, 1., 1.08});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Check the status before touching the outputs: a failed op may not populate them.
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
    // conv2d_bp in NCHW layout with VALID padding.
    // Weights and the expected weight gradient are listed as {oC, iC, kH, kW} and permuted
    // with {2,3,1,0} into {kH, kW, iC, oC} before use.
    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567, 1.224,0.66 ,1.314, 2.82 ,1.512,1.386, 2.976,1.596,0.801, 1.71 ,0.912,0.657, 1.422,0.768,1.53 , 3.288,1.764,1.602, 3.444,1.848,0.927, 1.98 ,1.056,
                                                     0.747, 1.62 ,0.876,1.746, 3.756,2.016,1.818, 3.912,2.1  ,1.053, 2.25 ,1.2  ,0.837, 1.818,0.984,1.962, 4.224,2.268,2.034, 4.38 ,2.352,1.179, 2.52 ,1.344,
                                                     1.467, 3.06 ,1.596,3.186, 6.636,3.456,3.402, 7.08 ,3.684,1.845, 3.834,1.992,1.773, 3.69 ,1.92 ,3.834, 7.968,4.14 ,4.05 , 8.412,4.368,2.187, 4.536,2.352,
                                                     2.079, 4.32 ,2.244,4.482, 9.3  ,4.824,4.698, 9.744,5.052,2.529, 5.238,2.712,2.385, 4.95 ,2.568,5.13 ,10.632,5.508,5.346,11.076,5.736,2.871, 5.94 ,3.072});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,
                                                    1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,
                                                    2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,
                                                    2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00});

    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68, 1., 1.32});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);
    weights.permutei({2,3,1,0});
    expGradW.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Check the status before touching the outputs: a failed op may not populate them.
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, conv2d_bp_4) {
    // conv2d_bp with caller-provided output arrays on a degenerate 7x1 spatial input
    // (NCHW, SAME padding, float32). Besides the status, it pins the bias gradient, which
    // is the per-channel sum of gradO.
    int bS=1, iH=7,iW=1,  iC=2,oC=3,  kH=2,kW=1,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int oH=7,oW=1;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32);
    NDArray bias('c', {oC}, {1,2,3}, nd4j::DataType::FLOAT32);
    NDArray gradO('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32);

    NDArray gradI('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
    NDArray gradW('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32);
    NDArray gradB('c', {oC}, nd4j::DataType::FLOAT32);

    // gradO = linspace(0.01, 0.01) over {1, 3, 7, 1}. Each channel holds 7 consecutive
    // values: 0.01..0.07, 0.08..0.14 and 0.15..0.21, which sum to 0.28, 0.77 and 1.26.
    NDArray expGradB('c', {oC}, {0.28, 0.77, 1.26}, nd4j::DataType::FLOAT32);

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    auto status = op.execute({&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat}, {});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(expGradB.isSameShape(&gradB));
    ASSERT_TRUE(expGradB.equalsTo(&gradB));
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
    // conv3dnew_bp, NDHWC layout, SAME padding, with bias.
    // Verifies all three gradients returned by the op: gradI, gradW and gradB.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat = 1;             // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226, 0.343, 0.46 , 0.577, 1.172, 1.46 , 1.748, 2.036, 1.892, 2.288, 2.684, 3.08 , 1.284, 1.581, 1.878, 2.175, 4.458, 5.133, 5.808, 6.483, 6.186, 7.023, 7.86 , 8.697, 3.39 , 3.93 , 4.47 , 5.01 , 9.642, 10.803, 11.964, 13.125, 11.37 , 12.693, 14.016, 15.339,
5.266, 5.707, 6.148, 6.589, 12.98 , 13.916, 14.852, 15.788, 14.564, 15.608, 16.652, 17.696, 6.284, 7.166, 8.048, 8.93 , 17.896, 19.768, 21.64 , 23.512, 21.928, 24.016, 26.104, 28.192, 18.12 , 19.686, 21.252, 22.818, 45.852, 49.146, 52.44 , 55.734, 53.196, 56.814, 60.432, 64.05 ,
28.164, 30.216, 32.268, 34.32 , 67.884, 72.15 , 76.416, 80.682, 75.228, 79.818, 84.408, 88.998, 29.324, 30.854, 32.384, 33.914, 67.432, 70.6 , 73.768, 76.936, 73.192, 76.576, 79.96 , 83.344, 27.884, 30.062, 32.24 , 34.418, 66.28 , 70.744, 75.208, 79.672, 70.312, 74.992, 79.672, 84.352,
58.296, 61.806, 65.316, 68.826,133.98 , 141.162, 148.344, 155.526,141.324, 148.83 , 156.336, 163.842, 68.34 , 72.336, 76.332, 80.328,156.012, 164.166, 172.32 , 180.474,163.356, 171.834, 180.312, 188.79 , 61.292, 64.118, 66.944, 69.77 ,136.552, 142.312, 148.072, 153.832,142.312, 148.288, 154.264, 160.24 ,
9.298, 11.359, 13.42 , 15.481, 27.092, 31.268, 35.444, 39.62 , 27.812, 32.096, 36.38 , 40.664, 26.556, 29.769, 32.982, 36.195, 66.666, 73.173, 79.68 , 86.187, 68.394, 75.063, 81.732, 88.401, 28.662, 32.118, 35.574, 39.03 , 71.85 , 78.843, 85.836, 92.829, 73.578, 80.733, 87.888, 95.043,
29.89 , 32.275, 34.66 , 37.045, 70.004, 74.828, 79.652, 84.476, 71.588, 76.52 , 81.452, 86.384, 71.084, 75.854, 80.624, 85.394,163.048, 172.696, 182.344, 191.992,167.08 , 176.944, 186.808, 196.672,138.648, 146.046, 153.444, 160.842,310.236, 325.194, 340.152, 355.11 ,317.58 , 332.862, 348.144, 363.426,
148.692, 156.576, 164.46 , 172.344,332.268, 348.198, 364.128, 380.058,339.612, 355.866, 372.12 , 388.374,125.228, 130.646, 136.064, 141.482,274.792, 285.736, 296.68 , 307.624,280.552, 291.712, 302.872, 314.032, 92.684, 98.75 , 104.816, 110.882,211.432, 223.672, 235.912, 248.152,215.464, 227.92 , 240.376, 252.832,
178.824, 188.166, 197.508, 206.85 ,398.364, 417.21 , 436.056, 454.902,405.708, 424.878, 444.048, 463.218,188.868, 198.696, 208.524, 218.352,420.396, 440.214, 460.032, 479.85 ,427.74 , 447.882, 468.024, 488.166,157.196, 163.91 , 170.624, 177.338,343.912, 357.448, 370.984, 384.52 ,349.672, 363.424, 377.176, 390.928});
    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12, 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. ,
154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,
111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 , 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52,
67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36,
85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 ,
61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04});
    // The bias gradient is gradO summed over every axis except the channel axis.
    // gradO holds 0.01, 0.02, ..., 2.16 with oC innermost and bS*oD*oH*oW = 72 positions, so
    // gradB[c] = 0.01 * sum_{p=0..71} (3*p + c + 1) = 77.4 + 0.72*c.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{77.4, 78.12, 78.84});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);
    auto* gradB = results->at(2);

    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
    // conv3dnew_bp, NDHWC layout, VALID padding, with bias.
    // Verifies all three gradients returned by the op: gradI, gradW and gradB.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat = 1;             // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014, 0.032, 0.05 , 0.068, 0.118, 0.181, 0.244, 0.307, 0.212, 0.257, 0.302, 0.347, 0.208, 0.298, 0.388, 0.478, 1.028, 1.262, 1.496, 1.73 , 1.036, 1.18 , 1.324, 1.468, 0.928, 1.018, 1.108, 1.198, 2.9 , 3.134, 3.368, 3.602, 2.188, 2.332, 2.476, 2.62 ,
1.202, 1.274, 1.346, 1.418, 3.142, 3.313, 3.484, 3.655, 2.048, 2.147, 2.246, 2.345, 0.532, 0.676, 0.82 , 0.964, 2.324, 2.666, 3.008, 3.35 , 2.008, 2.206, 2.404, 2.602, 3.584, 3.98 , 4.376, 4.772,10.552,11.452,12.352,13.252, 7.4 , 7.904, 8.408, 8.912,
6.752, 7.148, 7.544, 7.94 ,17.752,18.652,19.552,20.452,11.432,11.936,12.44 ,12.944, 5.932, 6.184, 6.436, 6.688,14.42 ,14.978,15.536,16.094, 8.704, 9.01 , 9.316, 9.622, 3.11 , 3.236, 3.362, 3.488, 7.39 , 7.669, 7.948, 8.227, 4.388, 4.541, 4.694, 4.847,
8.56 , 8.866, 9.172, 9.478,19.892,20.558,21.224,21.89 ,11.548,11.908,12.268,12.628,11.008,11.314,11.62 ,11.926,25.22 ,25.886,26.552,27.218,14.428,14.788,15.148,15.508, 7.322, 7.502, 7.682, 7.862,16.462,16.849,17.236,17.623, 9.248, 9.455, 9.662, 9.869,
0.158, 0.392, 0.626, 0.86 , 1.27 , 1.765, 2.26 , 2.755, 1.22 , 1.481, 1.742, 2.003, 2.224, 2.746, 3.268, 3.79 , 6.788, 7.886, 8.984,10.082, 4.78 , 5.356, 5.932, 6.508, 6.4 , 6.922, 7.444, 7.966,15.572,16.67 ,17.768,18.866, 9.388, 9.964,10.54 ,11.116,
4.802, 5.09 , 5.378, 5.666,11.206,11.809,12.412,13.015, 6.512, 6.827, 7.142, 7.457, 6.004, 6.58 , 7.156, 7.732,14.996,16.202,17.408,18.614, 9.208, 9.838,10.468,11.098,17.984,19.244,20.504,21.764,42.808,45.436,48.064,50.692,25.256,26.624,27.992,29.36 ,
28.064,29.324,30.584,31.844,63.832,66.46 ,69.088,71.716,36.2 ,37.568,38.936,40.304,18.316,19. ,19.684,20.368,40.916,42.338,43.76 ,45.182,22.816,23.554,24.292,25.03 , 8.438, 8.78 , 9.122, 9.464,18.91 ,19.621,20.332,21.043,10.58 ,10.949,11.318,11.687,
20.944,21.682,22.42 ,23.158,46.388,47.918,49.448,50.978,25.66 ,26.452,27.244,28.036,26.848,27.586,28.324,29.062,58.628,60.158,61.688,63.218,31.996,32.788,33.58 ,34.372,16.106,16.502,16.898,17.294,34.894,35.713,36.532,37.351,18.896,19.319,19.742,20.165});
    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16});
    // The bias gradient is gradO summed over every axis except the channel axis.
    // gradO holds 0.01, 0.02, ..., 0.48 with oC innermost and bS*oD*oH*oW = 16 positions, so
    // gradB[c] = 0.01 * sum_{p=0..15} (3*p + c + 1) = 3.76 + 0.16*c.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{3.76, 3.92, 4.08});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));
    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
// conv3dnew_bp with NCDHW layout, VALID padding and a bias; checks gradI, gradW and gradB.
int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
int oD=2,oH=2,oW=2;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
// Weights are created (and filled below) in {oC, iC, kD, kH, kW} order, then permuted to
// {kD, kH, kW, iC, oC} before the op call.
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091, 4.356, 2.268, 4.53 , 9.42 , 4.896, 4.65 , 9.672, 5.028, 2.517, 5.226, 2.712, 4.932,10.242, 5.316,10.62 ,22.02 ,11.412,10.908,22.62 ,11.724, 5.868,12.15 , 6.288, 2.913, 6.03 , 3.12 , 6.234,12.888, 6.66 , 6.402,13.236, 6.84 , 3.423, 7.068, 3.648,
2.415, 5.04 , 2.628, 5.25 ,10.932, 5.688, 5.37 ,11.184, 5.82 , 2.913, 6.054, 3.144, 5.724,11.898, 6.18 ,12.348,25.62 ,13.284,12.636,26.22 ,13.596, 6.804,14.094, 7.296, 3.381, 7.002, 3.624, 7.242,14.976, 7.74 , 7.41 ,15.324, 7.92 , 3.963, 8.184, 4.224,
2.739, 5.724, 2.988, 5.97 ,12.444, 6.48 , 6.09 ,12.696, 6.612, 3.309, 6.882, 3.576, 6.516,13.554, 7.044,14.076,29.22 ,15.156,14.364,29.82 ,15.468, 7.74 ,16.038, 8.304, 3.849, 7.974, 4.128, 8.25 ,17.064, 8.82 , 8.418,17.412, 9. , 4.503, 9.3 , 4.8 ,
3.063, 6.408, 3.348, 6.69 ,13.956, 7.272, 6.81 ,14.208, 7.404, 3.705, 7.71 , 4.008, 7.308,15.21 , 7.908,15.804,32.82 ,17.028,16.092,33.42 ,17.34 , 8.676,17.982, 9.312, 4.317, 8.946, 4.632, 9.258,19.152, 9.9 , 9.426,19.5 ,10.08 , 5.043,10.416, 5.376,
5.619,11.484, 5.868,11.73 ,23.964,12.24 ,12.138,24.792,12.66 , 6.333,12.93 , 6.6 ,12.42 ,25.362,12.948,25.884,52.836,26.964,26.748,54.588,27.852,13.932,28.422,14.496, 6.873,14.022, 7.152,14.298,29.16 ,14.868,14.754,30.084,15.336, 7.671,15.636, 7.968,
6.807,13.896, 7.092,14.178,28.932,14.76 ,14.586,29.76 ,15.18 , 7.593,15.486, 7.896,14.94 ,30.474,15.54 ,31.068,63.348,32.292,31.932,65.1 ,33.18 ,16.596,33.822,17.232, 8.205,16.722, 8.52 ,17.034,34.704,17.676,17.49 ,35.628,18.144, 9.075,18.48 , 9.408,
7.995,16.308, 8.316,16.626,33.9 ,17.28 ,17.034,34.728,17.7 , 8.853,18.042, 9.192,17.46 ,35.586,18.132,36.252,73.86 ,37.62 ,37.116,75.612,38.508,19.26 ,39.222,19.968, 9.537,19.422, 9.888,19.77 ,40.248,20.484,20.226,41.172,20.952,10.479,21.324,10.848,
9.183,18.72 , 9.54 ,19.074,38.868,19.8 ,19.482,39.696,20.22 ,10.113,20.598,10.488,19.98 ,40.698,20.724,41.436,84.372,42.948,42.3 ,86.124,43.836,21.924,44.622,22.704,10.869,22.122,11.256,22.506,45.792,23.292,22.962,46.716,23.76 ,11.883,24.168,12.288});
// Expected weight gradient, written in {oC, iC, kD, kH, kW} order; it is permuted below exactly
// like `weights` so that it is compared in the same layout as the op's gradW output.
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28,
5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28,
7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84,
7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84,
10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4,
10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4});
// Expected bias gradient = per-channel sum of gradO over batch and spatial axes,
// e.g. channel 0: 0.01*(1+...+8) + 0.01*(25+...+32) = 2.64.
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64, 3.92, 5.2 });
input = 2.;
weights.linspace(0.1, 0.1);
gradO.linspace(0.01, 0.01);
// {oC, iC, kD, kH, kW} -> {kD, kH, kW, iC, oC}; linspace above ran in the original order.
weights.permutei({2, 3, 4, 1, 0});
expGradW.permutei({2, 3, 4, 1, 0});
nd4j::ops::conv3dnew_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
// Outputs are ordered gradI, gradW, gradB.
auto* gradI = results->at(0);
auto* gradW = results->at(1);
auto* gradB = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
ASSERT_TRUE(expGradB.isSameShape(gradB));
ASSERT_TRUE(expGradB.equalsTo(gradB));
delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) {
    // depthwise_conv2d forward pass: NHWC layout, SAME padding, no bias.
    // The input is constant (2) and the weights are 0.1, 0.2, ..., so each output value is
    // determined by which kernel taps land inside the image at that position.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=4,oW=3;
    int paddingMode = 1;    // 1-SAME, 0-VALID;
    int dataFormat  = 1;    // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
    input = 2.;
    weights.linspace(0.1, 0.1);

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2,
        13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. , 2.4, 2.8, 3.2,
        12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2,
        13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. , 2.4, 2.8, 3.2});

    nd4j::ops::depthwise_conv2d op;
    auto res = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expOutput.isSameShape(out));
    ASSERT_TRUE(expOutput.equalsTo(out));

    delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_2) {
    // depthwise_conv2d forward pass: NHWC layout, VALID padding, no bias, double precision.
    // With VALID padding every output sees the full kernel, so the same four
    // per-channel values repeat at every output position.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=2,oW=2;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 1;    // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    input = 2.;
    weights.linspace(0.1, 0.1);

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,
        13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8});

    nd4j::ops::depthwise_conv2d op;
    auto res = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expOutput.isSameShape(out));
    ASSERT_TRUE(expOutput.equalsTo(out));

    delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_3) {
    // depthwise_conv2d forward pass: NCHW layout, VALID padding, with a per-output-channel bias.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=2,oW=2;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 0;    // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<double>('c', {mC, iC, kH, kW});
    auto biases  = NDArrayFactory::create<double>('c', {iC*mC}, {1,2,3,4});
    input = 2.;
    // fill in {mC, iC, kH, kW} order first, then view the weights as {kH, kW, iC, mC}
    weights.linspace(0.1, 0.1);
    weights.permutei({2,3,1,0});

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8});

    nd4j::ops::depthwise_conv2d op;
    auto res = op.execute({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expOutput.isSameShape(out));
    ASSERT_TRUE(expOutput.equalsTo(out));

    delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_4) {
    // Large NHWC depthwise_conv2d with stride 2 and SAME padding, writing into a
    // caller-provided output. The output is pre-filled with a sentinel value; after the op
    // runs, no element may still hold that sentinel.
    int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=56,oW=56;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat = 1;              // 1-NHWC, 0-NCHW
    const float unique = -1000000;

    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
    NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
    input.linspace(0.1, 0.0001);
    weights = 0.5;
    output = unique;

    nd4j::ops::depthwise_conv2d op;
    Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);

    // Scan every element with an integer index, so an unwritten region anywhere in the
    // output is caught, not only in its tail; ASSERT_NE also reports the offending value.
    for (Nd4jLong i = 0; i < output.lengthOf(); ++i)
        ASSERT_NE(unique, output.e<float>(i)) << "output element " << i << " was not written by the op";
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_5) {
    // depthwise_conv2d forward pass with all-ones weights: NHWC, SAME padding, mC == 1.
    // Each output is the sum of the input values covered by the 2x2 window in its channel.
    int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 1;    // 1-SAME, 0-VALID;
    int dataFormat  = 1;    // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    input.linspace(1.);
    weights = 1.;

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});

    nd4j::ops::depthwise_conv2d op;
    auto res = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expOutput.isSameShape(out));
    ASSERT_TRUE(expOutput.equalsTo(out));

    delete res;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_6) {
    // Same scenario as depthwise_conv2d_5, but the arrays are built with NDArray constructors
    // (explicit DOUBLE dtype) instead of NDArrayFactory.
    int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 1;    // 1-SAME, 0-VALID;
    int dataFormat  = 1;    // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::DOUBLE);
    input.linspace(1.);
    weights = 1.;

    NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});

    nd4j::ops::depthwise_conv2d op;
    ResultSet* resSet = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    NDArray* actual = resSet->at(0);

    ASSERT_EQ(Status::OK(), resSet->status());
    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resSet;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
// depthwise_conv2d_bp with NHWC layout and SAME padding; checks gradI and gradW.
// NOTE(review): a bias is passed in but no bias gradient is read back or checked here —
// consider asserting it too once its output index is confirmed.
int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=4,oW=3;
int oC=iC*mC;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
auto bias = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
auto gradO = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308,
1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028});
auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08});
input = 2.;
weights.linspace(0.1, 0.1);
gradO.linspace(0.01, 0.01);
nd4j::ops::depthwise_conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0);
auto* gradW = results->at(1);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
// depthwise_conv2d_bp with NHWC layout and VALID padding; checks gradI and gradW.
// With VALID padding and a constant input, gradW repeats the same four values for every tap.
// NOTE(review): a bias is passed in but no bias gradient is read back or checked here.
int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=2,oW=2;
int oC=iC*mC;
int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
auto bias = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
auto gradO = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729,
0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481});
auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88});
input = 2.;
weights.linspace(0.1, 0.1);
gradO.linspace(0.01, 0.01);
nd4j::ops::depthwise_conv2d_bp op;
auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto* gradI = results->at(0);
auto* gradW = results->at(1);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expGradI.isSameShape(gradI));
ASSERT_TRUE(expGradI.equalsTo(gradI));
ASSERT_TRUE(expGradW.isSameShape(gradW));
ASSERT_TRUE(expGradW.equalsTo(gradW));
delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
// conv3dnew forward pass: NDHWC layout, SAME padding, no bias.
// input == 2 and weights == 1, so each output is 2 * (number of in-bounds kernel taps) * iC;
// the interior maximum 96 corresponds to the full 2*3*2 kernel over 4 input channels.
int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
// Expected shape {bS, oD, oH, oW, oC}; with SAME padding and unit strides oD/oH/oW equal iD/iH/iW.
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.});
input = 2.;
weights = 1.;
nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {
// conv3dnew forward pass: NDHWC layout, SAME padding, no bias, non-uniform weights (0.1, 0.2, ...).
// Same geometry as conv3d_test1; the values now also differ per output channel.
int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6,
170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2,
534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6,
170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2});
input = 2.;
weights.linspace(0.1, 0.1);
nd4j::ops::conv3dnew op;
auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
auto* output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {
    // conv3dnew forward pass: NDHWC layout, VALID padding, no bias.
    // The input is constant, so every output position sees an identical receptive field and
    // the same three per-channel values repeat across the whole output.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 1;    // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    input = 2.;
    weights.linspace(0.1, 0.1);

    // {bS, oD, oH, oW, oC} with oD = oH = oW = 2 under VALID padding
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3}, {686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
        686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
        686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
        686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6});

    nd4j::ops::conv3dnew op;
    auto res = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete res;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test4) {
    // conv3dnew forward pass: NCDHW input, {kD, kH, kW, iC, oC} weights, VALID padding, no bias.
    // input == 2 and weights == 0.5, so every output is kD*kH*kW*iC * (2*0.5) = 48.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 0;    // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    input   = 2.;
    weights = 0.5;

    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2});
    expected = 48.;

    nd4j::ops::conv3dnew op;
    auto res = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete res;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test5) {
    // Same as conv3d_test4 plus a bias of 1 on every output channel: 48 + 1 = 49.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 0;    // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias    = NDArrayFactory::create<TypeParam>('c', {oC});
    input   = 2.;
    weights = 0.5;
    bias    = 1.;

    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2});
    expected = 49.;

    nd4j::ops::conv3dnew op;
    auto res = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete res;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {
    // Same as conv3d_test4 with a per-channel bias {1, 2, 3}: each NCDHW channel block of
    // 2x2x2 outputs is 48 plus that channel's bias.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat  = 0;    // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias    = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3});
    input   = 2.;
    weights = 0.5;

    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49., 49.,49., 49., 49., 49.,49., 49., 50., 50.,50., 50., 50., 50.,50., 50.,
        51., 51.,51., 51., 51., 51.,51., 51., 49., 49.,49., 49., 49., 49.,49., 49.,
        50., 50.,50., 50., 50., 50.,50., 50., 51., 51.,51., 51., 51., 51.,51., 51.});

    nd4j::ops::conv3dnew op;
    auto res = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    auto* out = res->at(0);

    ASSERT_EQ(Status::OK(), res->status());
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete res;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {
    // VALID 3D convolution with linspace weights and bias {1, 2, 3}.
    // The weights are filled in {oC, iC, kD, kH, kW} order and then permuted to the
    // {kD, kH, kW, iC, oC} layout used by the other conv3d tests above, so each output
    // channel sees a contiguous block of 48 linspace values.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. ,
                                                                 698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,
                                                                 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. ,
                                                                 698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8});
    input = 2.;
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});   // {oC, iC, kD, kH, kW} -> {kD, kH, kW, iC, oC}

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {
    // Same as conv3d_test7 but without a bias input: the expected values are exactly
    // 1, 2 and 3 lower per output channel than in conv3d_test7.
    int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. ,
                                                                 1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2,
                                                                 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. ,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8});
    input = 2.;
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});   // {oC, iC, kD, kH, kW} -> {kD, kH, kW, iC, oC}

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test9) {
    // Shape-only check: integer args are kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW,
    // paddingMode (1 - SAME), dataFormat (1 - NDHWC).
    auto input    = NDArrayFactory::create<TypeParam>('c', {4, 2, 28, 28, 3});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 3, 4});
    auto expected = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

    nd4j::ops::conv3dnew conv;
    auto resultSet = conv.execute({&input, &weights}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
    ASSERT_EQ(Status::OK(), resultSet->status());

    auto output = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(output));

    delete resultSet;
}
TYPED_TEST(TypedConvolutionTests1, conv3d_test10) {
    // Runs the same configuration as conv3d_test9 through execute(), then checks that
    // calling calculateOutputShape() directly with identical arguments yields the same shape.
    auto x   = NDArrayFactory::create<TypeParam>('c', {4, 2, 28, 28, 3});
    auto w   = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 3, 4});
    auto exp = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

    nd4j::ops::conv3dnew op;
    auto result = op.execute({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
    ASSERT_EQ(Status::OK(), result->status());
    auto z = result->at(0);
    ASSERT_TRUE(exp.isSameShape(z));

    // The integer args must match the execute() call above:
    // kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode (1 - SAME), dataFormat (1 - NDHWC).
    // With paddingMode 0 (VALID) the computed shape would differ from `exp`.
    // An earlier variant passed dataFormat = 0 here, and that caused a failure.
    ShapeList shapeList({x.shapeInfo(), w.shapeInfo()});
    Context ctx(1);
    for (Nd4jLong arg : {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1})
        ctx.getIArguments()->push_back(arg);

    auto shapes = op.calculateOutputShape(&shapeList, ctx);
    ASSERT_EQ(1, shapes->size());
    // Previously the computed shape was fetched but never checked.
    ASSERT_TRUE(shape::shapeEquals(exp.shapeInfo(), shapes->at(0)));

    delete result;
    shapes->destroy();
    delete shapes;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {
    // 1x1 (pointwise) convolution on NHWC input filled with 2.
    // The weights {1, 1, iC, oC} hold 0.1 .. 1.2, so the per-output-channel sums over iC are
    // 2.2, 2.6 and 3.0. Times the input (2) plus the bias (1), this gives 5.4, 6.2 and 7.0
    // at every spatial position.
    int bS=2, iH=4,iW=3, iC=4,oC=3;
    int dataFormat = 1;             // 1-NHWC, 0-NCHW

    auto input     = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights   = NDArrayFactory::create<TypeParam>('c', {1, 1, iC, oC});
    auto bias      = NDArrayFactory::create<TypeParam>('c', {oC});
    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2,
                                                                   7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4,
                                                                   6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0,
                                                                   5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0});
    input = 2.;
    weights.linspace(0.1, 0.1);
    bias = 1.;

    nd4j::ops::pointwise_conv2d op;
    auto results = op.execute({&input, &weights, &bias}, {}, {dataFormat});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test11) {
    // SAME-padded 3D convolution with unit strides on a tiny 2x2x2 volume.
    // With stride 1 and SAME padding, the spatial output size equals the input size.
    int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oC, iD, iH, iW});
    input   = 2.;
    weights = 1.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    // Previously the output was fetched but never checked.
    ASSERT_TRUE(output->isSameShape(&expected));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test12) {
    // Shape check for a VALID 3D convolution: with unit strides/dilations and no padding,
    // each output spatial dim is input - kernel + 1 (4-2+1=3, 14-2+1=13).
    int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=13,oW=13;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});
    input   = 2.;
    weights = 1.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(output->isSameShape(&expected));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, vol2col_test1) {
    // vol2col on a contiguous 'c'-ordered {bS, iC, iD, iH, iW} volume filled with 1, 2, 3, ...
    // The columns array {bS, iC, kD, kH, kW, oD, oH, oW} is pre-filled with -1, so any position
    // vol2col left unwritten would not match the expected zeros.
    int bS=2, iD=2,iH=3,iW=2, iC=3, kD=2,kH=3,kW=2;
    int sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=3,oW=2;

    NDArray vol('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
    vol.linspace(1);

    NDArray cols('c', {bS, iC, kD, kH, kW, oD, oH, oW}, nd4j::DataType::FLOAT32);
    cols = -1.;

    NDArray expectedCols('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., 4., 0., 6.,0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., 10., 0., 12., 0., 0., 0., 5., 6.,
            0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17.,18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0.,
            0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23.,
            24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0.,
            34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0.,
            0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40.,
            41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0.,
            0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54.,0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0.,
            53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0.,0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0.,
            0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69.,
            70., 71., 72., 0., 0., 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0.,
            0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);

    graph::Context block(1);
    nd4j::ops::ConvolutionUtils::vol2col(block, vol, cols, sD, sH, sW, pD, pH, pW, dD, dH, dW);

    ASSERT_TRUE(cols.equalsTo(expectedCols));
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, vol2col_test2) {
    // Same expectations as vol2col_test1, but both the volume and the columns are
    // non-contiguous permuted views. The volume is permuted BEFORE linspace, so the
    // logical {bS, iC, iD, iH, iW} order still holds 1, 2, 3, ...
    int bS=2, iD=2,iH=3,iW=2, iC=3, kD=2,kH=3,kW=2;
    int sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=3,oW=2;

    auto vol = NDArrayFactory::create<float>('c', {iD, bS, iH, iC, iW});
    vol.permutei({1, 3, 0, 2, 4});      // -> {bS, iC, iD, iH, iW}
    vol.linspace(1);

    auto cols = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH});
    cols.permutei({5, 1, 0, 2, 4, 6, 7, 3});   // -> {bS, iC, kD, kH, kW, oD, oH, oW}
    cols = -1.;

    auto expectedCols = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,
            10., 11., 12., 2., 0., 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0.,0., 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8.,
            9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0.,
            0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0.,
            23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0.,
            0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36.,
            0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33.,
            34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40.,
            0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47.,
            48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
            0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., 59., 60., 0., 0.,
            0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60.,
            0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., 64., 0., 66.,
            0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 69., 70., 71., 72., 0., 0.,
            0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});

    graph::Context block(1);
    nd4j::ops::ConvolutionUtils::vol2col(block, vol, cols, sD, sH, sW, pD, pH, pW, dD, dH, dW);

    ASSERT_TRUE(cols.equalsTo(expectedCols));
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, col2im_test1) {
    // col2im folds {bS, iC, kH, kW, oH, oW} columns (filled with 1, 2, 3, ...) back into a
    // {bS, iC, iH, iW} image. The image starts at -2, which does not appear in the expected result.
    int bS=2, iH=2,iW=2, iC=2, kH=2,kW=2;
    int sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;

    auto cols = NDArrayFactory::create<float>('c', {bS, iC, kH, kW, oH, oW});
    cols.linspace(1);

    auto img = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    img = -2.;

    auto expectedImg = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1., 7., 12., 34., 17., 39., 44., 98., 33., 71., 76., 162., 49., 103., 108., 226.});

    LaunchContext launchCtx;
    nd4j::ops::helpers::col2im(launchCtx, cols, img, sH, sW, pH, pW, iH, iW, dH, dW);

    ASSERT_TRUE(img.equalsTo(expectedImg));
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling2d_test1) {
    // Nearest-neighbour upsampling of an NHWC input filled with 1..36: each pixel is repeated
    // factorH times along H and factorW times along W.
    const int bS=3, iH=2,iW=2, iC=3;
    const int factorH=2, factorW=3;
    const int isNCHW = 0;           // 0 - NHWC (as used here), 1 - NCHW; the default is NCHW

    auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6.,
                                                                          7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12.,
                                                                         13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,
                                                                         19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,
                                                                         25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,
                                                                         31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.});

    nd4j::ops::upsampling2d op;
    auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling2d_test2) {
    // Same as upsampling2d_test1 but with an NCHW input.
    const int bS=3, iH=2,iW=2, iC=3;
    const int factorH=2, factorW=3;
    const int isNCHW = 1;           // 0 - NHWC, 1 - NCHW (as used here); the default is NCHW

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1., 1., 1., 2., 2., 2., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 3., 3., 3., 4., 4., 4.,
        5., 5., 5., 6., 6., 6., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 9., 9., 9., 10., 10., 10.,11., 11., 11., 12., 12., 12.,11., 11., 11., 12., 12., 12.,
       13., 13., 13., 14., 14., 14.,13., 13., 13., 14., 14., 14.,15., 15., 15., 16., 16., 16.,15., 15., 15., 16., 16., 16.,17., 17., 17., 18., 18., 18.,17., 17., 17., 18., 18., 18.,19., 19., 19., 20., 20., 20.,19., 19., 19., 20., 20., 20.,
       21., 21., 21., 22., 22., 22.,21., 21., 21., 22., 22., 22.,23., 23., 23., 24., 24., 24.,23., 23., 23., 24., 24., 24.,25., 25., 25., 26., 26., 26.,25., 25., 25., 26., 26., 26.,27., 27., 27., 28., 28., 28.,27., 27., 27., 28., 28., 28.,
       29., 29., 29., 30., 30., 30.,29., 29., 29., 30., 30., 30.,31., 31., 31., 32., 32., 32.,31., 31., 31., 32., 32., 32.,
       33., 33., 33., 34., 34., 34.,33., 33., 33., 34., 34., 34.,35., 35., 35., 36., 36., 36.,35., 35., 35., 36., 36., 36.});

    nd4j::ops::upsampling2d op;
    auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_test1) {
    // Nearest-neighbour 3D upsampling of a channels-last {bS, iD, iH, iW, iC} input filled
    // with 1..72: each voxel is repeated factorD/factorH/factorW times along D/H/W.
    const int bS=3, iD=2,iH=2,iW=2, iC=3;
    const int factorD=2,factorH=3,factorW=2;
    const int isNCDHW = 0;          // 0 - NDHWC (as used here), 1 - NCDHW; channels-first is the default

    auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,
     7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,
     7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,
    19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,
    13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,
    25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,
    25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,
    31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,
    43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,
    43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,
    49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,
    49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,
    61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,
    67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,
    67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.});

    nd4j::ops::upsampling3d op;
    auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_test2) {
    // Same as upsampling3d_test1 but with a channels-first {bS, iC, iD, iH, iW} input.
    const int bS=3, iD=2,iH=2,iW=2, iC=3;
    const int factorD=2,factorH=3,factorW=2;
    const int isNCDHW = 1;          // 0 - NDHWC, 1 - NCDHW (as used here); channels-first is the default

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8.,
     5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12.,
    13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,
    17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,
    25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,
    29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,
    37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,
    41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,
    49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,
    53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,
    61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,
    65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.});

    nd4j::ops::upsampling3d op;
    auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {
    // Backprop of 3D upsampling with an all-ones output gradient: each input element gathers
    // factorD*factorH*factorW = 8 gradient entries, so the expected input gradient is 8 everywhere.
    const int bS=1, iD=2,iH=2,iW=2, iC=1;
    const int factorD=2, factorH=2, factorW=2;
    const int isNCDHW = 1;          // 0 - NDHWC, 1 - NCDHW (as used here); channels-first is the default

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW});
    gradO = 1.;

    auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    expGradI = 8.;

    nd4j::ops::upsampling3d_bp op;
    auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
    // Check the status before touching the outputs.
    ASSERT_EQ(Status::OK(), results->status());
    auto* gradI = results->at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    delete results;
}
TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {
    // conv2d_input_bp: the first input is the target input *shape* given as an array
    // ({2, 1, 4, 4}), followed by the weights and the incoming gradient epsilonNext.
    // Integer args presumably follow the conv2d convention:
    // kH,kW, sH,sW, pH,pW, dH,dW, paddingMode — TODO confirm against the op declaration.
    auto inputShape  = NDArrayFactory::create<TypeParam>('c', {4}, {2, 1, 4, 4});
    auto weights     = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});
    auto shapeArr    = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});

    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, shapeArr.getShapeInfo());

    weights.linspace(1);
    epsilonNext.linspace(1);
    // weights are filled as {2, 1, 3, 3} and then viewed as {3, 3, 1, 2}
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_input_bp op;
    auto results = op.execute({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1});
    // Previously the op status was never checked; verify it before reading any output.
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_EQ(1, results->size());

    auto epsilon = results->at(0);
    ASSERT_TRUE(shapeArr.isSameShape(epsilon));
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {
    // Backprop of 3D upsampling with two channels and random gradO values.
    // Each expected gradI entry is the sum of the corresponding
    // factorD x factorH x factorW block of gradO.
    const int bS=1, iD=3,iH=3,iW=3, iC=2;
    const int factorD=2, factorH=2, factorW=2;
    const int isNCDHW = 1;  // 1 - channels-first (NCDHW) layout, matching the shapes below

    NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
    NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338,
        0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668,
        0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439,
        0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839,
        0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561,
        0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631,
        0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227,
        0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047,
        0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033,
        0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843,
        0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876,
        0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908,
        0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415,
        0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304,
        0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846,
        0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239,
        0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083,
        0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075,
        0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565,
        0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615,
        0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824,
        0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565,
        0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802,
        0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821,
        0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676,
        0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527,
        0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158,
        0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731,
        0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452,
        0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176,
        0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142,
        0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622,
        0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457,
        0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804,
        0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821,
        0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333,
        0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278,
        3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016,
        4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917,
        4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856,
        4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);

    nd4j::ops::upsampling3d_bp op;
    auto results = op.execute({&input, &gradO}, {}, {isNCDHW});
    // Check the status before touching any output, so a failing op is reported as a
    // status mismatch instead of reading a possibly incomplete result set.
    ASSERT_EQ(Status::OK(), results->status());

    auto* gradI = results->at(0);
    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test1) {
    // deconv2d, VALID padding, NHWC layout.
    // The op input is {bS, oH, oW, oC} (3x3 spatial) and the output is {bS, iH, iW, iC} (4x4 spatial).
    const int bS = 2, iH = 4, iW = 4, iC = 5, oC = 10;
    const int kH = 2, kW = 2;
    const int sH = 1, sW = 1;
    const int pH = 0, pW = 0;
    const int dH = 1, dW = 1;
    const int oH = 3, oW = 3;
    const int paddingMode = 0;  // 1-SAME, 0-VALID
    const int dataFormat  = 1;  // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
    input = 0.5;
    weights.linspace(0.1, 0.1);

    auto expected = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
        2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75});

    nd4j::ops::deconv2d op;
    auto resultSet = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), resultSet->status());

    auto out = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(out));
    ASSERT_TRUE(expected.equalsTo(out));

    delete resultSet;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, deconv2d_test2) {
    // deconv2d, SAME padding, NHWC layout.
    // The op input is {bS, oH, oW, oC} and the output is {bS, iH, iW, iC}; both are 4x4 spatially.
    int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=4;
    int paddingMode = 1; // 1-SAME, 0-VALID;
    int dataFormat = 1; // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
    auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
        2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. });
    input = 0.5;
    weights.linspace(0.1, 0.1);

    nd4j::ops::deconv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching any output, so a failing op is reported as a
    // status mismatch instead of reading a possibly incomplete result set.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
    // deconv2d_tf, VALID padding, NHWC layout. Unlike deconv2d, the requested output shape
    // {bS, iH, iW, iC} is passed explicitly as the first input (outShape), before the weights
    // and the {bS, oH, oW, oC} data array. The expected values match deconv2d_test1.
    int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=3,oW=3;
    int paddingMode = 0; // 1-SAME, 0-VALID;
    int dataFormat = 1; // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
    auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75,
        2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
        52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75});
    input = 0.5;
    weights.linspace(0.1, 0.1);

    nd4j::ops::deconv2d_tf op;
    auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // Check the status before touching any output, so a failing op is reported as a
    // status mismatch instead of reading a possibly incomplete result set.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete results;
}
#endif //LIBND4J_CONVOLUTIONTESTS1_H