/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_CONVOLUTIONTESTS1_H
#define LIBND4J_CONVOLUTIONTESTS1_H

#include "testlayers.h"

#include <memory>

#include <NDArray.h>
#include <Context.h>
#include <Node.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <PointersManager.h>

#ifdef HAVE_MKLDNN
#include <ops/declarable/platform/mkldnn/mkldnnUtils.h>
#endif

using namespace nd4j;
using namespace nd4j::graph;

// Empty fixture: adds no members, set-up or tear-down on top of testing::Test.
class ConvolutionTests1 : public testing::Test {};

// Empty typed fixture; T is the type each TYPED_TEST body sees as TypeParam.
template <typename T>
class TypedConvolutionTests1 : public testing::Test {};

// Types that every TypedConvolutionTests1 case is instantiated for.
using TestingTypes = ::testing::Types<double, float>;
TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes);

//////////////////////////////////////////////////////////////////////
// conv2d driven through an explicit VariableSpace/Context pair.
// Setup: NCHW input [bS, iC, iH, iW], 2x2 kernel, unit stride and dilation,
// zero padding arguments, SAME mode, no bias. Input and weights are filled
// with 1..N; weights are created as [oC, iC, kH, kW] and permuted in place
// with {2,3,1,0} before being passed to the op.
// The output is compared against a precomputed buffer.
//
// All helper objects live on the stack so that an ASSERT_* failure (which
// returns from the test body immediately) cannot leak them.
TYPED_TEST(TypedConvolutionTests1, conv2d_1) {

    int bS=1, iH=5,iW=4,  iC=2,oC=3,  kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0};
    Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};

    // NOTE(review): input and weights are handed to the VariableSpace below and were
    // never deleted by this test, so they are presumably owned by it - confirm.
    auto input = NDArrayFactory::create_<TypeParam>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create_<TypeParam>('c', {oC, iC, kH, kW});
    for (int e = 0; e < input->lengthOf(); e++)
        input->p(e, e + 1);

    for (int e = 0; e < weights->lengthOf(); e++)
        weights->p(e, e + 1);
    weights->permutei({2,3,1,0});

    // Stamp the data type of the input into the hand-written expected shape info.
    ArrayOptions::setDataType(_expS, input->dataType());
    // Wraps the stack buffers above (assumed non-owning, exactly as the former heap-allocated wrapper).
    NDArray exp(_expB, _expS);

    VariableSpace variableSpace;
    variableSpace.putVariable(-1, input);
    variableSpace.putVariable(-2, weights);

    Context block(1, &variableSpace, false);  // not-in-place
    block.fillInputs({-1, -2});

    auto iArgs = block.getIArguments();
    for (int arg : {kH, kW,      // 2,2 kernel
                    sH, sW,      // 1,1 stride
                    pH, pW,      // 0,0 padding
                    dH, dW,      // 1,1 dilation
                    1,           // padding mode: 1-SAME, 0-VALID
                    0})          // data format:  1-NHWC, 0-NCHW
        iArgs->push_back(arg);

    nd4j::ops::conv2d op;

    Nd4jStatus status = op.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto res = variableSpace.getVariable(1)->getNDArray();

    // checking output shape: [bS, oC, oH, oW], the test expects oH == iH and oW == iW
    ASSERT_EQ(bS, res->sizeAt(0));
    ASSERT_EQ(oC, res->sizeAt(1));
    ASSERT_EQ(iH, res->sizeAt(2));
    ASSERT_EQ(iW, res->sizeAt(3));

    // basically the same as above
    ASSERT_TRUE(res->isSameShape(&exp));
    // final check
    ASSERT_TRUE(res->equalsTo(&exp));

    // block, variableSpace and exp are destroyed here in that order,
    // matching the order of the manual deletes this replaces.
}

//////////////////////////////////////////////////////////////////////
// 1x1 kernel, unit stride/dilation, zero padding, paddingMode=0 and dataFormat=0.
// Input is [1, 1, 1, 4] filled with 1..4, weights are [1, 1, 1, 4] filled with 2.
// The expected [1, 4, 1, 4] output is the input row doubled, once per output channel.
TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
    auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.});

    weights.assign(2.0);
    input.linspace(1);

    nd4j::ops::conv2d op;
    // Held in a unique_ptr so the ResultSet is released even when an ASSERT_* returns early.
    std::unique_ptr<nd4j::ResultSet> result(op.execute({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}));
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

//////////////////////////////////////////////////////////////////////
// NHWC conv2d with SAME padding and no bias.
// Input [bS, iH, iW, iC] is filled with 2, weights [kH, kW, iC, oC] with 0.1, 0.2, ...
// The output [bS, oH, oW, oC] is compared against precomputed values.
TYPED_TEST(TypedConvolutionTests1, conv2d_3) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f,  155.2f,  158.4f, 152.f,  155.2f,  158.4f,  66.4f,   68.f,   69.6f, 170.4f,  175.2f,  180.f, 170.4f,  175.2f,  180.f,  70.8f,   73.2f,   75.6f,
                                                      170.4f,  175.2f,  180.f, 170.4f,  175.2f,  180.f,  70.8f,   73.2f,   75.6f,  75.2f,   78.4f,   81.6f,  75.2f,   78.4f,   81.6f,  28.f,   29.6f,   31.2f,
                                                      152.f,  155.2f,  158.4f, 152.f,  155.2f,  158.4f,  66.4f,   68.f,   69.6f, 170.4f,  175.2f,  180.f, 170.4f,  175.2f,  180.f,  70.8f,   73.2f,   75.6f,
                                                      170.4f,  175.2f,  180.f, 170.4f,  175.2f,  180.f,  70.8f,   73.2f,   75.6f,  75.2f,   78.4f,   81.6f,  75.2f,   78.4f,   81.6f,  28.f,   29.6f,   31.2f});
    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
    // Held in a unique_ptr so the ResultSet is released even when an ASSERT_* returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat}));

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}

//////////////////////////////////////////////////////////////////////
// NHWC conv2d with VALID padding and no bias.
// Input [bS, iH, iW, iC] is filled with 2, weights [kH, kW, iC, oC] with 0.1, 0.2, ...
// The output [bS, oH, oW, oC] is compared against precomputed values.
TYPED_TEST(TypedConvolutionTests1, conv2d_4) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f});

    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv2d op;
    // Held in a unique_ptr so the ResultSet is released even when an ASSERT_* returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat}));

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}

//////////////////////////////////////////////////////////////////////
// NCHW conv2d with VALID padding and a bias of {1, 2, 3}.
// Input [bS, iC, iH, iW] is filled with 2; weights are created as [oC, iC, kH, kW],
// filled with 0.1, 0.2, ... and then permuted in place with {2,3,1,0}.
// The output [bS, oC, oH, oW] is compared against precomputed values.
TYPED_TEST(TypedConvolutionTests1, conv2d_5) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {61.f,   61.f,  61.f,   61.f, 177.2f,  177.2f, 177.2f,  177.2f, 293.4f,  293.4f, 293.4f,  293.4f,  61.f,   61.f,  61.f,   61.f, 177.2f,  177.2f, 177.2f,  177.2f, 293.4f,  293.4f, 293.4f,  293.4f});

    input = 2.;
    weights.linspace(0.1, 0.1);
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d op;
    // Held in a unique_ptr so the ResultSet is released even when an ASSERT_* returns early.
    std::unique_ptr<nd4j::ResultSet> results(op.execute({&input, &weights, &bias}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat}));

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}

//////////////////////////////////////////////////////////////////////
// Smoke test: only checks that conv2d reports success; output values are not verified.
// The kernel size is passed as -1,-1.
// NOTE(review): presumably -1 makes the op take kH/kW from the weights shape - confirm.
// Remaining arguments: unit stride, zero padding, unit dilation, paddingMode=1, dataFormat=1.
TYPED_TEST(TypedConvolutionTests1, conv2d_6) {
    auto input = NDArrayFactory::create<TypeParam>('c', {54, 1, 12, 12});
    auto weights = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2});

    nd4j::ops::conv2d op;
    // Held in a unique_ptr so the ResultSet is released even if the assertion below fails.
    std::unique_ptr<nd4j::ResultSet> result(op.execute({&input, &weights}, {}, {-1,-1,  1,1,  0,0,  1,1,  1,1}));
    ASSERT_EQ(Status::OK(), result->status());
}

TYPED_TEST(TypedConvolutionTests1, TestAvgFF_TF) {
    auto input = NDArrayFactory::create<TypeParam>('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,             2.91434479f,             5.43639755f,             -2.10573769f,             4.08528662f,             5.86908436f,             -4.46203756f,             2.21057916f,             5.35849190f,             0.01394637f,             4.40566349f,             7.07982206f,             -0.09633455f,             2.42429352f,             3.97301817f,             -1.89553940f,             1.99690318f,             6.33141708f,             0.55401880f,             1.70707977f,             5.55204201f,             -0.03513752f,             1.60011971f,             2.62700319f,             -2.74582434f,             3.06697464f,             1.06277943f,             -1.16075921f,             -0.78095782f,             9.72352791f,             -1.22686064f,             1.99644792f,             7.35571337f,             1.40607321f,             0.11390255f,             9.53334427f,             2.28303599f,             -1.66728830f,             6.16678810f,             -0.04532295f,             -1.97708666f,             9.74906158f,             1.46223176f,             -1.46734393f,             4.30761862f,             -1.23790228f,             1.24823606f,             6.13938427f,             -3.83689475f,             -1.19625473f,             7.91535568f,             6.05868721f,             -3.22946382f,             8.81633949f,             -0.19967777f,             0.66053957f,             2.30919123f,             0.74543846f,             -0.39347672f,             11.11058044f,             0.53720862f,             1.52645731f,             5.70012379f,             -1.15213466f,             1.16451406f,             7.00526333f,             1.57362783f,             -2.44384766f,             5.54213285f,             -1.98828590f,             -0.70483637f,             7.88281822f,             -3.59875536f,             0.80745387f,             13.41578484f,             
-1.55507684f,             -0.65855008f,             9.32583523f,             -0.14544789f,             0.73436141f,             3.61176538f,             -1.71268058f,             -2.58490300f,             9.09280205f,             -3.27405524f,             -2.04569697f,             4.44761324f,             -0.62955856f,             -2.61917663f,             8.04890442f,             0.54579324f,             0.85929775f,             9.82259560f,             -1.93825579f,             0.77703512f,             4.67090321f,             -4.79267597f,             -2.38906908f,             9.31265545f,             0.96026313f,             -1.14109385f,             11.54231834f,             -0.01417295f,             -0.39500344f,             8.49191666f,             0.55300158f,             2.79490185f,             6.92466164f,             1.72254205f,             2.82222271f,             8.83112717f,             2.95033407f,             2.18054962f,             6.73509789f,             -2.22272944f,             0.51127720f,             -1.04563558f,             2.15747333f,             -2.30959272f,             9.55441570f,             1.50396204f,             1.77370787f,             7.38146257f,             -1.79076433f,             3.20961165f,             7.18864202f,             2.91217351f,             0.43018937f,             7.11078024f,             -1.17386127f,             -0.16817921f,             6.12327290f,             -2.82205725f,             3.30696845f,             13.51291752f,             -1.30856836f,             -2.38332748f,             11.09487438f,             -1.47190213f,             -0.53050828f,             4.38285351f,             -5.07309771f,             1.50714362f,             5.72274446f,             -2.85825086f,             -0.89673209f,             3.73791552f,             -0.67708802f,             -4.13149452f,             -0.00671843f,             -0.26566532f,             0.32961160f,             7.14501762f,             -1.41608179f, 
            -4.96590328f,             12.26205540f,             -0.65158135f,             -0.88641000f,             6.95777559f,             -0.79058206f,             -0.10260171f,             7.87169170f,             1.35921454f,             1.11759663f,             5.46187401f,             -2.57214499f,             2.48484039f,             4.04043484f,             -2.07137156f,             -1.42709637f,             9.25487137f,             -0.12605135f,             -2.66949964f,             2.89412403f,             0.74451172f,             -2.96250391f,             3.99258423f,             0.27084303f,             0.32213116f,             5.42332172f,             -0.44414216f,             1.70881832f,             6.69346905f,             0.53058422f,             -4.73146200f,             4.22051668f,             2.24834967f,             0.66996074f,             4.30173683f,             0.11849818f,             -4.07520294f,             8.27318478f,             -2.54398274f,             -2.86705542f,             10.11775303f,             -0.99382895f,             0.65881538f,             7.93556786f,             -1.27934420f,             -1.69343162f,             9.68042564f,             -1.02609646f,             -1.18189347f,             5.75370646f,             -1.67888868f,             -4.48871994f,             4.79537392f,             -0.79212248f,             -0.19855022f,             6.15060997f,             -0.01081491f,             3.64454579f,             10.82562447f,             1.58859253f,             -2.65847278f,             8.60093212f,             -1.59196103f,             0.07635692f,             11.76175690f,             -1.17453325f,             0.10122013f,             6.86458445f,             -2.18891335f,             -2.74004745f,             8.07066154f,             0.71818852f,             -2.03035975f,             6.31053686f,             0.51509416f,             1.39789927f,             9.43515587f,             2.04256630f,             
0.13985133f,             4.65010691f,             2.40911126f,             -0.36255789f,             -3.06867862f,             -0.45225358f,             -1.56778407f,             6.05917358f,             -1.09891272f,             1.77184200f,             6.46248102f,             0.96042323f,             -0.24346280f,             4.63436460f,             -4.69907761f,             1.25187206f,             11.46173859f,             -2.21917558f,             1.28007793f,             6.92173195f,             2.11268163f,             -3.47389889f,             5.08722782f,             -3.03950930f,             -4.17154264f,             11.30568314f,             0.80361372f,             2.53214502f,             7.18707085f,             -4.49114513f,             2.85449266f,             10.14906883f,             -0.31974933f,             -0.84472644f,             -0.52459574f,             0.12921631f,             -1.81390119f,             2.76170087f,             1.03982210f,             2.91744232f,             -0.29048753f,             5.87453508f,             -1.53684759f,             1.85800636f,             -0.91404629f,             1.28954852f,             5.11354685f,             -2.47475505f,             -1.33179152f,             2.58552408f,             1.37316465f,             -3.32339454f,             1.54122913f,             3.24953628f,             -0.29758382f,             2.82391763f,             -1.51142192f,             -1.22699404f,             6.75745535f,             0.65452754f,             -3.29385471f,             2.06008053f,             2.53172946f,             -4.23532820f,             -1.53909743f,             -0.07010663f,             -1.42173731f,             7.29031610f,             -0.18448229f,             4.59496164f,             6.73027277f,             0.73441899f,             0.14426160f,             4.14915276f,             -2.97010231f,             6.05851364f,             4.95218086f,             -2.39145470f,             2.40494704f,  
           2.10288811f,             0.53503096f,             1.44511235f,             6.66344261f,             -3.05803776f,             7.21418667f,             3.30303526f,             -0.24163735f,             3.47409391f,             3.64520788f,             2.15189481f,             -3.11243272f,             3.62310791f,             0.37379482f,             0.40865007f,             -0.83132005f,             -4.78246069f,             2.07030797f,             6.51765442f,             3.16178989f,             5.06180477f,             3.78434467f,             -0.96689719f,             0.35965276f,             5.89967585f,             1.40294051f,             1.11952639f,             10.59778214f,             0.26739889f,             -1.61297631f,             6.24801159f,             -0.93914318f,             -0.57812452f,             9.92604542f,             -0.73025000f,             -3.38530874f,             2.45646000f,             -2.47949195f,             0.51638460f,             10.65636063f,             1.97816694f,             -3.00407791f,             2.66914415f,             -0.81951088f,             -0.23316640f,             2.40737987f,             -2.70007610f,             1.51531935f,             4.08860207f,             -0.27552786f,             -1.31721711f,             7.11568260f,             -3.33498216f,             -4.02545023f,             7.22675610f,             -0.81690705f,             -2.52689576f,             1.04016697f,             -0.79291463f,             -0.34875512f,             10.00498390f,             -4.24167728f,             1.46162593f,             11.82569408f,             -1.70359993f,             -0.30161047f,             16.44085884f,             -0.82253462f,             -0.09435523f,             6.13080597f,             -0.20259480f,             0.68308711f,             6.15663004f,             -6.61776876f,             0.33295766f,             2.55449438f,             -0.17819691f,             -1.14892209f,             
5.56776142f,             1.99279118f,             1.33035934f,             4.45823956f,             3.34916544f,             -2.59905386f,             6.16164446f,             -2.03881931f,             -2.45273542f,             12.46793365f,             -2.22743297f,             2.83738565f,             8.48628139f,             -1.39347959f,             -1.30867767f,             11.08041477f,             -4.00363779f,             2.09183025f,             11.30395889f,             -2.20504737f,             1.37426853f,             8.98735619f,             1.04676604f,             -0.72757077f,             8.28050232f,             -6.70741081f,             -0.65798020f,             5.68592072f,             -0.60760021f,             0.35854483f,             6.26852131f,             1.94100165f,             1.32112014f,             0.80987954f,             -1.74617672f,             -0.25434083f,             7.16045523f,             1.58884013f,             -2.64847064f,             13.14820385f,             1.21393633f,             -2.47258949f,             9.41650105f,             -0.79384226f,             2.48954105f,             10.95629311f,             0.47723705f,             4.02126694f,             8.02593136f,             -2.20726371f,             -1.18794477f,             1.50836647f,             0.93118095f,             -1.73513174f,             8.85493565f,             -2.99670315f,             -0.79055870f,             2.39473820f,             2.05046916f,             -2.38055134f,             11.82299423f,             0.15609655f,             0.68744308f,             5.66401434f,             -0.69281673f,             2.09855556f,             7.74626589f,             -0.34283102f,             1.00542057f,             9.95838642f,             0.80161905f,             2.33455157f,             9.80057335f,             -0.93561798f,             2.56991577f,             8.29711342f,             0.94213426f,             0.44209945f,             11.70259857f,     
        0.92710167f,             2.60957146f,             0.24971688f,             -0.86529571f,             3.78628922f,             6.80884457f,             -0.68178189f,             2.21103406f,             3.18895817f,             0.60283208f,             -2.92716241f,             6.72060776f,             -1.06625068f,             2.56543374f,             9.97404480f,             3.58080721f,             -0.94936347f,             10.16736984f,             -1.38464379f,             1.18191063f,             6.66179037f,             -3.56115270f,             0.32329530f,             10.90870762f,             2.20638227f,             0.19653285f,             7.34650040f,             -3.63859272f,             -1.03027737f,             5.98829985f,             -3.66606474f,             -3.89746714f,             8.63469028f,             1.22569811f,             1.63240814f,             3.74385309f,             0.58243257f,             -0.56981975f,             3.69260955f,             1.00979900f,             -1.44030499f,             8.57058144f,             -1.10648811f,             1.20474911f,             5.43133020f,             -2.14822555f,             -0.07928789f,             11.25825310f,             0.19645604f,             -5.49546146f,             10.41917038f,             -0.68178523f,             -2.99639869f,             6.50054455f,             0.46488351f,             -5.42328453f,             9.09500027f,             -2.82107449f,             0.05601966f,             15.34610748f,             -0.06820253f,             3.86699796f,             10.73316956f,             -3.04795432f,             -0.14702171f,             5.64813185f,             1.44028485f,             -2.47596145f,             0.07280898f,             -3.03187990f,             -1.35183525f,             9.35835648f,             2.72966957f,             1.88199532f,             10.36187744f,             -0.22834805f,             -3.26738238f,             6.92025137f,             
-2.34061313f,             4.77379704f,             5.28559113f,             -2.96323752f,             -1.76186585f,             5.94436455f,             0.38647744f,             -5.73869514f,             6.76849556f,             1.40892124f,             -1.19068217f,             5.37919092f,             -6.65328646f,             3.62782669f,             12.34744644f,             2.44762444f,             -4.19242620f,             6.14906216f,             0.08121119f,             0.61355996f,             2.69666457f,             -1.88962626f,             -0.55314136f,             1.84937525f,             1.56048691f,             1.17460012f,             3.75674725f,             1.06198275f,             -5.74625874f,             5.41645575f,             -1.28946674f,             -1.51689398f,             4.32400894f,             -0.05222082f,             -4.83948946f,             1.80747867f,             1.63144708f,             -2.73887825f,             1.63975775f,             -2.02163982f,             -0.16210437f,             2.93518686f,             1.14427686f,             -2.83246303f,             4.79283667f,             2.69697428f,             -3.12678456f,             -1.19225168f,             -2.37022972f,             -3.09429741f,             1.94225383f,             -1.13747168f,             -2.55048585f,             5.40242243f,             1.12777328f,             3.43713188f,             3.62658787f,             -2.16878843f,             0.30164462f,             2.97407579f,             -0.07275413f,             -1.31149673f,             4.70066261f,             -2.01323795f,             4.85255766f,             4.59128904f,             1.68084168f,             1.60336494f,             6.58138466f,             -1.04759812f,             2.69906545f,             3.55769277f,             -0.74327278f,             2.65819693f,             5.39528131f,             2.11248922f,             -1.06446671f,             5.24546766f,             -2.43146014f,     
        4.58907509f,             0.06521678f,             -2.24503994f,             2.45722699f,             6.94863081f,             0.35258654f,             2.83396196f,             9.92525196f,             -1.12225175f,             -0.34365177f,             7.19116688f,             -4.39813757f,             0.46517885f,             13.22028065f,             -2.57483673f,             -6.37226963f,             7.58046293f,             -2.74600363f,             0.42231262f,             8.04881668f,             0.17289802f,             -0.53447008f,             16.55157471f,             -5.63614368f,             0.39288223f,             3.37079263f,             1.26484549f,             -0.12820500f,             8.46440125f,             -4.39304399f,             2.97676420f,             0.65650189f,             0.83158541f,             -1.11556435f,             6.32885838f,             -0.36087769f,             2.80724382f,             9.90292645f,             1.15936041f,             0.20947981f,             6.91249275f,             -2.67404819f,             2.93782163f,             6.65656614f,             -2.30828357f,             2.98214006f,             6.80611229f,             -4.93821478f,             -7.66555262f,             7.59763002f,             -0.54159302f,             3.87403512f,             12.42607784f,             2.59284401f,             -0.23375344f,             8.95293331f,             -0.71807784f,             0.61873478f,             8.66713524f,             1.24289191f,             -2.37835455f,             2.08071637f,             -0.88315344f,             -3.41891551f,             6.85245323f,             1.73007369f,             1.02169311f,             7.69170332f,             -2.85411978f,             2.69790673f,             8.12906551f,             -1.19351399f,             -2.26442742f,             12.26104450f,             -0.75579089f,             -1.73274946f,             10.68729019f,             2.20655656f,             
-0.90522075f,             12.42165184f,             -1.67929137f,             2.44851565f,             9.31565762f,             -0.06645700f,             1.52762020f,             6.18427515f,             -1.68882596f,             3.70261097f,             3.02252960f,             -3.44125366f,             -1.31575799f,             2.84617424f,             -0.96849400f,             -4.52356243f,             9.95027161f,             0.19966406f,             -0.78874779f,             8.18595028f,             -4.08300209f,             1.75126517f,             0.96418417f,             -4.04913044f,             -0.95200396f,             12.03637886f,             -0.03041124f,             0.41642749f,             8.88267422f,             -3.24985337f,             -2.24919462f,             7.32566118f,             0.16964148f,             -2.74123430f,             7.05264473f,             -3.30191112f,             0.17163286f,             4.81851053f,             -1.64463484f,             -0.85933101f,             7.29276276f,             2.34066939f,             -2.14860010f,             3.46148157f,             -0.01782012f,             1.51504040f,             4.79304934f,             1.85281146f,             -1.70663762f,             6.93470192f,             -4.15440845f,             -1.25983095f,             10.52491760f,             0.42930329f,             -1.85146868f,             11.70042324f,             -0.41704914f,             3.83796859f,             9.21148491f,             -2.79719448f,             0.79470479f,             6.26926661f,             -5.85230207f,             3.95105338f,             7.84790897f,             -1.38680744f,             -1.78099084f,             11.95235348f,             -2.99841452f,             -1.34507811f,             6.15714645f,             -1.07552516f,             -2.81228638f,             1.66234732f,             -4.55166149f,             -1.92601109f,             8.64634514f,             -0.48158705f,             
3.31595659f,             7.67371941f,             2.56964207f,             0.12107098f,             4.56467867f,             -0.93541539f,             1.39432955f,             11.99714088f,             1.05353570f,             -2.13099813f,             3.67617917f,             3.45895386f,             1.37365830f,             8.74344158f,             -4.17585802f,             1.43908918f,             6.28764772f,             3.97346330f,             -0.69144285f,             9.07983303f,             -0.41635889f,             -0.14965028f,             8.85469818f,             1.11306190f,             2.59440994f,             5.38982344f,             -1.07948279f,             1.37252975f,             10.26984596f,             -0.09318046f,             2.73104119f,             12.45902252f,             -1.55446684f,             -2.76124811f,             12.19395065f,             -0.51846564f,             1.02764034f,             11.42673588f,             -0.95940983f,             -0.04781032f,             8.78379822f,             -4.88957930f,             0.32534006f,             11.97696400f,             -3.35108662f,             1.95104563f,             4.46915388f,             -2.32061648f,             3.45230985f,             8.29983711f,             2.81034684f,             -2.35529327f,             6.07801294f,             -0.98105043f,             -0.05359888f,             2.52291036f,             -0.01986909f,             -2.35321999f,             10.51954269f,             2.11145401f,             3.53506470f,             7.29093266f,             0.03721160f,             -1.13496494f,             7.43886709f,             -5.84201956f,             2.50796294f,             12.14647675f,             2.77490377f,             -2.18896222f,             6.05641937f,             5.32617044f,             1.04221284f,             10.79106712f,             -2.95749092f,             -2.75414610f,             11.30037117f,             -3.40654182f,             
-2.24673963f,             7.49126101f,             0.70811015f,             -6.18003702f,             13.83951187f,             -1.01204085f,             1.36298490f,             -1.04451632f,             2.42435336f,             -0.02346706f,             -0.85528886f,             1.04731262f,             0.22192979f,             4.15708160f,             0.34933877f,             0.04814529f,             2.24107265f,             0.49676740f,             -1.47752666f,             0.45040059f,             -0.70471478f,             -1.19759345f,             0.21711677f,             0.88461423f,             -2.76830935f,             5.52066898f,             1.97664857f,             -1.75381601f,             3.45877838f,             1.52617192f,             -1.61350942f,             0.85337949f,             1.97610760f,             -3.40310287f,             3.40319014f,             -3.38691044f,             -0.71319139f,             1.65463758f,             -0.60680127f,             -1.80700517f,             8.02592373f,             2.59627104f,             2.65895891f,             5.93043184f,             -4.48425817f,             3.92670918f,             4.19496679f,             -2.28286791f,             6.41634607f,             5.72330523f,             1.16269672f,             -0.28753027f,             2.46342492f,             0.36693189f,             0.26712441f,             6.37652683f,             -2.50139046f,             2.43923736f,             5.56310415f,             0.98065847f,             1.04267502f,             4.16403675f,             -0.04966142f,             4.40897894f,             3.72905660f,             -3.46129870f,             3.59962773f,             1.34830284f,             -1.76661730f,             0.47943926f,             5.29946661f,             -1.12711561f,             1.26970029f,             15.17655945f,             -1.50971997f,             5.81345224f,             8.48562050f,             -4.36049604f,             2.48144460f,         
    8.23780441f,             -3.46030426f,             -0.84656560f,             5.94946814f,             1.12747943f,             -2.65683913f,             8.69085693f,             1.31309867f,             -2.79958344f,             8.76840591f,             -1.56444156f,             1.62710834f,             2.41177034f,             -0.72804940f,             5.70619011f,             4.67169666f,             -0.86167198f,             -1.83803177f,             2.96346045f,             2.82692933f,             -2.81557131f,             7.11113358f,             -1.90071094f,             2.54244423f,             11.19284058f,             -0.06298946f,             -1.71517313f,             12.98388577f,             0.84510714f,             3.00816894f,             2.57200313f,             0.03899818f,             -1.49330592f,             9.60099125f,             -3.59513044f,             -1.30045319f,             7.09241819f,             -0.65233821f,             -2.33627677f,             8.81366920f,             0.84154201f,             1.03312039f,             9.85289097f,             0.19351870f,             1.78496623f,             7.34631205f,             -2.16530800f,             -0.65016162f,             2.46842360f,             0.24016285f,             -1.24308395f,             4.78175163f,             -0.97682536f,             2.20942235f,             6.68382788f,             3.76786447f,             -1.44454038f,             6.26453733f,             -3.23575711f,             -2.30137897f,             9.53092670f,             -5.55222607f,             3.25999236f,             9.37559509f,             1.86339056f,             -0.23551451f,             10.23400211f,             3.93031883f,             -0.52629089f,             7.85724449f,             -2.91549587f,             4.46612740f,             5.66530371f,             -2.70820427f,             4.81359577f,             10.31247330f,             1.92230141f,             2.53931546f,             0.74986327f, 
            1.70303428f,             0.48063779f,             5.31099129f,             -0.78976244f,             3.75864220f,             4.23051405f,             2.34042454f,             -7.98193836f,             9.83987141f,             -1.46722627f,             3.54497814f,             10.36455154f,             -4.51249075f,             0.77715248f,             7.78694630f,             -4.59989023f,             -2.49585629f,             9.90296268f,             1.38535416f,             1.17441154f,             10.10452843f,             -0.98628229f,             0.60194463f,             9.12639141f,             -3.90754628f,             2.88526392f,             7.24123430f,             -0.15283313f,             -0.75728363f,             -1.15116858f,             -2.53791571f,             0.77229571f,             6.44114161f,             0.02646767f,             4.95463037f,             7.21066380f,             1.79384065f,             0.73250306f,             8.04447937f,             0.32576546f,             -0.79447043f,             10.12717724f,             2.33392906f,             1.30716443f,             12.36073112f,             -0.36694977f,             -1.20438910f,             7.03105593f,             0.59557682f,             0.69267452f,             10.18113136f,             2.49944925f,             -0.42229167f,             8.83143330f,             -1.18805945f,             -2.87509322f,             4.53596449f,             4.09732771f,             -3.39088297f,             -1.02536607f,             0.82119560f,             -3.47302604f,             9.29991817f,             0.21001509f,             4.97036457f,             9.50018406f,             1.04420102f,             1.96560478f,             10.74769592f,             -6.22709799f,             3.11690164f,             5.06759691f,             -1.23724771f,             -3.05831861f,             8.12925529f,             -1.93435478f,             -1.10151744f,             9.32263088f,             
-0.04249470f,             -5.98547363f,             10.49398136f,             0.26400441f,             -0.78915191f,             13.28219604f,             2.99276900f,             0.74853164f,             2.49364305f,             -3.43529654f,             4.05278301f,             2.13498688f,             -2.35444307f,             -0.79900265f,             4.66968822f,             -0.31095147f,             3.60674143f,             12.37222099f,             -0.07855003f,             -3.30292702f,             12.15215874f,             0.60886210f,             2.87075138f,             7.75271845f,             0.38044083f,             3.34402204f,             6.40583277f,             -0.87888050f,             0.67438459f,             6.91080809f,             1.98332930f,             -0.08303714f,             8.08630371f,             -0.16772588f,             -2.74058914f,             7.17253590f,             -2.69122696f,             1.48173678f,             8.99470139f,             -1.43302310f,             -0.88651133f,             2.66944790f,             -0.29186964f,             2.00838661f,             5.09587479f,             -0.76676071f,             -2.88322186f,             8.31110573f,             -0.14550979f,             -1.37726915f,             10.28355122f,             -1.60575438f,             -0.04118848f,             9.97510815f,             0.14440438f,             -3.24632120f,             9.00034523f,             4.14319563f,             -1.31023729f,             7.16950464f,             -0.70428526f,             2.01559544f,             7.26155043f,             2.40816474f,             2.09847403f,             7.31264496f,             -0.75401551f,             2.13392544f,             7.03648758f,             1.04036045f,             -1.15636516f,             1.09634531f,             -0.06340861f,             -0.58107805f,             -0.65623116f,             1.18972754f,             -0.80717683f,             1.40118241f,             
-0.61932516f,             -3.60596156f,             1.59904599f,             -2.23774099f,             -1.13721037f,             3.89620137f,             -0.09115922f,             -7.51356888f,             2.36975193f,             -1.42520905f,             -2.34173775f,             3.33830214f,             -2.74016523f,             -3.04115510f,             6.00119495f,             -1.36084354f,             -2.45065260f,             4.56992292f,             -3.02825928f,             -3.74182844f,             5.11069250f,             -0.91531068f,             -2.31385994f,             1.83399653f,             3.39370203f,             -3.60886002f});
    auto exp = NDArrayFactory::create<TypeParam>('c', {4, 4, 4, 3}, {7.97172260f,  0.06878620f,              2.27749538f,              7.29276514f,              -0.14074677f,              0.65480286f,              5.70313978f,              -0.06546132f,              0.35443667f,              3.70382833f,              -0.84020567f,              0.63826996f,              8.60301399f,              -0.38236514f,              1.55177069f,              7.37542057f,              -0.99374938f,              -0.29971302f,              8.84352493f,              -0.67121059f,              0.43132120f,              4.78175592f,              -1.25070143f,              -1.91523600f,              6.03855371f,              -0.00292124f,              -1.11214364f,              7.90158176f,              -0.57949901f,              -0.96735370f,              7.81192017f,              -0.53255427f,              -0.48009714f,              3.16953635f,              0.08353355f,              -1.54299748f,              3.74821687f,              1.69396687f,              0.72724354f,              5.42915201f,              -1.13686812f,              -0.71793109f,              5.78376389f,              -0.72239977f,              -0.60055625f,              2.53636408f,              0.56777251f,              -2.07892323f,              6.08064651f,              0.68620735f,              2.54017019f,              5.65828180f,              -0.68255502f,              1.47283304f,              6.10842514f,              -0.39655915f,              0.28380761f,              1.96707797f,              -1.98206317f,              0.94027776f,              4.71811438f,              0.32104525f,              -0.92409706f,              8.34588146f,              -1.05581069f,              -0.55217457f,              9.58440876f,              -0.96549922f,              0.45820439f,              5.65453672f,              -2.50953507f,              -0.71441835f,              8.03059578f,              -0.21281289f,  
            0.92125505f,              9.26900673f,              -0.35963219f,              -0.70039093f,              8.59924412f,              -1.22358346f,              0.81318003f,              3.85920119f,              -0.01305223f,              -1.09234154f,              6.33158875f,              1.28094780f,              -1.48926139f,              4.94969177f,              -0.77126902f,              -1.97033751f,              5.64381838f,              -0.16285487f,              -1.31277227f,              2.39893222f,              -1.32902908f,              -1.39609122f,              6.47572327f,              -0.45267010f,              1.55727172f,              6.70965624f,              -1.68735468f,              -0.05672536f,              7.25092363f,              -0.64613032f,              0.67050058f,              3.60789680f,              -2.05948973f,              2.22687531f,              8.15202713f,              -0.70148355f,              1.28314006f,              8.14842319f,              -1.88807654f,              -1.04808438f,              8.45500565f,              -0.76425624f,              0.94542569f,              4.56179953f,              -0.28786001f,              -2.04502511f,              8.46278095f,              -0.31019822f,              0.07339200f,              9.34214592f,              -0.61948007f,              0.52481830f,              8.32515621f,              -1.52418160f,              0.49678251f,              5.11082315f,              -1.09908783f,              -0.52969611f,              5.27806664f,              0.88632923f,              0.66754371f,              4.75839233f,              0.48928693f,              -0.68036932f,              6.56925392f,              -0.02949905f,              -2.99189186f,              4.46320581f,              -0.64534980f,              -0.29516968f,              8.60809517f,              -1.13120568f,              3.41720533f,              5.84243155f,              -1.24109328f,              
0.89566326f,              5.99578333f,              -0.42496428f,              2.07076764f,              3.17812920f,              -0.81566459f,              -0.14363396f,              6.55184317f,              0.39633346f,              -0.43852386f,              8.70214558f,              -2.24613595f,              0.30708700f,              8.73882294f,              -0.53545928f,              1.54409575f,              4.49452257f,              -0.16509305f,              0.19028664f,              8.24897003f,              0.44750381f,              2.15448594f,              8.97640514f,              -0.77728152f,              0.57272542f,              9.03467560f,              0.47173575f,              -1.10807717f,              3.30056310f,              -0.43268481f,              -0.41470885f,              3.53798294f,              -0.08546703f,              -2.16840744f,              6.18733406f,              -0.17871059f,              -2.59837723f,              5.94218683f,              -1.02990067f,              -0.49760687f,              3.76938033f,              0.86383581f,              -1.91504073f});

    nd4j::ops::avgpool2d op;
    auto result = op.execute({&input}, {}, {3,3,  3,3,  0,0,  1,1,1,  0,1}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    // z->printIndexedBuffer("z");
    // exp.printIndexedBuffer("e");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv2d_7) {
    // Smoke test: executes conv2d with a non-square 4x3 kernel on a constant-filled
    // 256x256 single-channel input and checks only that the op reports success.
    // The produced values are not compared against a reference.

    int bS=1, iH=256,iW=256,  iC=1,oC=1,  kH=4,kW=3,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    // int       oH=256,oW=256;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});

    // Fill both operands with a constant value.
    input = 5.;
    weights = 3.;

    nd4j::ops::conv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Verify the status right after execution. The output array is deliberately not
    // fetched: nothing in this test reads it, and asking the result set for element 0
    // before the status is known to be OK would touch a result that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, sconv2d_1) {
    // Runs sconv2d through an explicit VariableSpace/Context pair on a 2x2x10x10 input
    // and 5x5 kernels (stride 1, no padding, dilation 1, mode flag 0), then compares
    // the result against the reference values below.
    //
    // Reference buffer: 2*6*6*6 = 432 values, one line per 6-element output row,
    // grouped into twelve 6x6 planes (2 x 6 along the first two dimensions of _expS).
    float _expB[] = {
        // [0][0]
        10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f,
        13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f,
        16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f,
        19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f,
        23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f,
        26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f,
        // [0][1]
        38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f,
        54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f,
        70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f,
        86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f,
        101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f,
        117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f,
        // [0][2]
        67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f,
        95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f,
        124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f,
        152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f,
        180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f,
        208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f,
        // [0][3]
        119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f,
        128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f,
        138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f,
        147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f,
        157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f,
        166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f,
        // [0][4]
        273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f,
        295150.0f, 297350.0f, 299550.0f, 301750.0f, 303950.0f, 306150.0f,
        317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f,
        339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f,
        361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f,
        383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f,
        // [0][5]
        426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f,
        461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f,
        495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f,
        530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f,
        564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f,
        599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f,
        // [1][0]
        75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f,
        78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f,
        81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f,
        84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f,
        88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f,
        91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f,
        // [1][1]
        353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f,
        369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f,
        385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f,
        401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f,
        416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f,
        432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f,
        // [1][2]
        632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f,
        660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f,
        689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f,
        717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f,
        745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f,
        773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f,
        // [1][3]
        309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f,
        318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f,
        328400.0f, 329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f,
        337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f,
        347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f,
        356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f,
        // [1][4]
        713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f,
        735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f,
        757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f,
        779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f,
        801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f,
        823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f,
        // [1][5]
        1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f,
        1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f,
        1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f,
        1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f,
        1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f,
        1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,
    };

    // Hand-written shape descriptor: rank (4), the four dimensions {2, 6, 6, 6},
    // the four strides, then three trailing descriptor fields.
    // NOTE(review): the trailing fields look like type flags / element-wise stride /
    // order (99 == 'c') -- confirm against the shape-info layout.
    // The leading stride must be 6*6*6 = 216 for a dense buffer of this shape (the
    // data above holds 216 values per first-dimension slice); it was previously 144,
    // which did not match the declared dimensions.
    Nd4jLong _expS[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, 8192, 1, 99};
    NDArray exp(_expB, _expS);

    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int iC = 2;
    int oC = 3;
    int kY = 5;
    int kX = 5;
    int iY = 10;
    int iX = 10;
    int B = 2;

    // Input holds the sequence 1, 2, 3, ... in linear order.
    auto input = NDArrayFactory::create_<float>('c', {B, iC, iY, iX});
    for (int e = 0; e < input->lengthOf(); e++)
        input->p(e, e+1);

    // Weights are filled the same way as {oC, iC, kY, kX} and then permuted in place.
    auto weights = NDArrayFactory::create_<float>('c', {oC, iC, kY, kX});
    for (int e = 0; e < weights->lengthOf(); e++)
        weights->p(e, e+1);
    weights->permutei({2,3,1,0});

    // Register both arrays under the negative ids that the context reads as its inputs.
    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, weights);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1, -2});

    // Integer arguments, in order: kernel, stride, padding, dilation, mode flag.
    auto iArgs = block->getIArguments();

    iArgs->push_back(kY);
    iArgs->push_back(kX);

    iArgs->push_back(sY);
    iArgs->push_back(sX);

    iArgs->push_back(pY);
    iArgs->push_back(pX);

    // dilation
    iArgs->push_back(1);
    iArgs->push_back(1);

    // NOT same mode
    iArgs->push_back(0);

    nd4j::ops::sconv2d op;

    Nd4jStatus status = op.execute(block);

    ASSERT_EQ(ND4J_STATUS_OK, status);

    // The result is read back from the variable stored under the context's node id (1).
    auto output = variableSpace->getVariable(1)->getNDArray();

    //exp.printShapeInfo("Expected shape");
    //output->printShapeInfo("Result shape");
    ASSERT_TRUE(exp.isSameShape(output));

    //exp.printBuffer("Expected buffer");
    //output->printBuffer("Result buffer");
    ASSERT_TRUE(exp.equalsTo(output));

    delete block;
    delete variableSpace;
}

//////////////////////////////////////////////////////////////////////
// Runs sconv2d on a {2, 3, 10, 10} input with a 5x5 weightsD kernel and a 1x1 weightsP kernel,
// then compares the single output array against the precomputed _expBFF reference buffer.
TYPED_TEST(TypedConvolutionTests1, sconv2d_2) {
    // Flat reference output values; they are interpreted through the _expSFF descriptor below.
    TypeParam _expBFF[] = {108.9405008,  109.5920008,  110.2435008,  110.8950008,   111.5465008,  112.1980008,  115.4555008,  116.1070008,   116.7585008,  117.410000,   118.061500,   118.7130009,   121.9705009,  122.6220009,  123.2735009,  123.9250009,   124.5765009,  125.2280009,  128.4855009,  129.1370009,   129.7885009,  130.4400009,  131.09150,    131.74300,   135.0005010,  135.6520010,  136.3035010,  136.9550010,   137.6065010,  138.2580010,  141.5155010,  142.1670010,   142.8185010,  143.4700010,  144.1215010,  144.7730010,   248.9617514,  250.670751,   252.3797515,  254.0887515,   255.7977515,  257.5067515,  266.0517515,  267.7607515,   269.469751,   271.1787516,  272.8877516,  274.5967516,   283.1417516,  284.8507516,  286.5597516,  288.268751,   289.9777517,  291.6867517,  300.2317517,  301.9407517,   303.6497517,  305.3587517,  307.067751,   308.7767518,   317.3217518,  319.0307518,  320.7397518,  322.4487518,   324.157751,   325.866751,   334.4117519,  336.1207519,   337.8297519,  339.5387519,  341.2477519,  342.95675,   388.9829964,  391.7494964,  394.5159964,  397.2824964,   400.048996,   402.8154963,  416.647996,   419.4144962,   422.1809962,  424.9474962,  427.7139962,  430.4804962,   444.3129961,  447.0794961,  449.8459961,  452.6124960,   455.3789960,  458.1454960,  471.9779959,  474.7444959,   477.5109959,  480.2774959,  483.0439959,  485.8104958,   499.6429958,  502.4094957,  505.1759957,  507.9424957,   510.7089957,  513.4754957,  527.3079956,  530.0744956,   532.8409956,  535.607495,   538.3739955,  541.1404955,   529.0042487,  532.8282487,  536.6522487,  540.4762487,   544.3002487,  548.1242487,  567.2442487,  571.068248,   574.892248,   578.716248,   582.540248,   586.3642486,   605.4842486,  609.3082486,  613.1322486,  616.9562486,   620.7802486,  624.6042486,  643.7242486,  647.5482486,   651.3722486,  655.1962486,  659.0202486,  662.8442486,   681.9642486,  685.7882486,  689.6122486,  693.4362486,   697.2602486,  701.0842486,  720.2042486, 
 724.0282486,   727.852248,   731.676248,   735.500248,   739.324248,   669.0255044,  673.9070044,  678.7885044,  683.6700044,   688.5515044,  693.4330044,  717.8405044,  722.7220044,   727.6035044,  732.4850044,  737.3665044,  742.2480044,   766.6555043,  771.5370043,  776.4185043,  781.3000043,   786.1815043,  791.0630043,  815.4705043,  820.3520043,   825.2335043,  830.1150043,  834.9965043,  839.8780043,   864.2855042,  869.1670042,  874.0485042,  878.9300042,   883.8115042,  888.6930042,  913.1005042,  917.9820042,   922.8635042,  927.7450042,  932.6265042,  937.5080042,   809.0467424,  814.9857424,  820.9247424,  826.8637423,   832.8027423,  838.7417423,  868.4367421,  874.3757421,   880.3147420,  886.2537420,  892.1927420,  898.13174,   927.8267418,  933.7657418,  939.7047417,  945.6437417,   951.5827417,  957.5217416,  987.2167415,  993.155741,   999.0947414, 1005.0337414, 1010.972741,  1016.9117413,   1046.6067412, 1052.5457411, 1058.4847411, 1064.4237411,   1070.3627410, 1076.3017410, 1105.996740,  1111.9357408,   1117.8747408, 1123.8137408, 1129.7527407, 1135.6917407,   949.0679815,  956.0644814,  963.060981,   970.0574813,   977.0539812,  984.0504811, 1019.0329807, 1026.0294807,   1033.0259806, 1040.0224805, 1047.0189804, 1054.0154804,   1088.9979800, 1095.9944799, 1102.9909798, 1109.987479,   1116.9839797, 1123.9804796, 1158.9629792, 1165.9594791,   1172.9559791, 1179.9524790, 1186.9489789, 1193.9454788,   1228.9279785, 1235.9244784, 1242.9209783, 1249.9174782,   1256.913978,  1263.9104781, 1298.8929777, 1305.8894776,   1312.8859775, 1319.8824775, 1326.8789774, 1333.8754773,   1089.0892560, 1097.1432561, 1105.1972562, 1113.251256,   1121.3052563, 1129.3592564, 1169.6292568, 1177.6832568,   1185.7372569, 1193.7912570, 1201.845257,  1209.8992571,   1250.1692575, 1258.2232576, 1266.2772576, 1274.3312577,   1282.3852578, 1290.4392579, 1330.7092582, 1338.7632583,   1346.8172584, 1354.8712584, 1362.9252585, 1370.9792586,   1411.24925,   1419.3032590, 
1427.3572591, 1435.4112592,   1443.465259,  1451.5192593, 1491.7892597, 1499.8432598,   1507.8972598, 1515.9512599, 1524.0052600, 1532.059260,   1229.1105073, 1238.2220073, 1247.3335073, 1256.4450073,   1265.5565073, 1274.668007,  1320.2255074, 1329.3370074,   1338.4485074, 1347.5600075, 1356.6715075, 1365.7830075,   1411.340507,  1420.4520076, 1429.5635076, 1438.6750076,   1447.7865076, 1456.8980076, 1502.4555077, 1511.5670077,   1520.6785077, 1529.7900077, 1538.9015077, 1548.013007,   1593.5705078, 1602.6820078, 1611.793507,  1620.9050079,   1630.0165079, 1639.1280079, 1684.6855080, 1693.7970080,   1702.9085080, 1712.0200080, 1721.1315080, 1730.2430080,   1369.1317613, 1379.3007614, 1389.4697614, 1399.6387615,   1409.8077615, 1419.976761,  1470.8217618, 1480.9907618,   1491.159761,  1501.3287619, 1511.4977619, 1521.6667620,   1572.5117622, 1582.6807622, 1592.8497623, 1603.0187623,   1613.1877624, 1623.3567624, 1674.2017626, 1684.3707627,   1694.5397627, 1704.7087628, 1714.8777628, 1725.046762,   1775.8917631, 1786.0607631, 1796.229763,  1806.3987632,   1816.5677632, 1826.7367633, 1877.5817635, 1887.7507635,   1897.9197636, 1908.0887636, 1918.2577637, 1928.4267637,   304.3905022,  305.0420022,  305.6935022,  306.3450022,   306.9965022,  307.6480022,  310.9055022,  311.5570022,   312.208502,   312.860002,   313.5115023,  314.1630023,   317.4205023,  318.0720023,  318.7235023,  319.3750023,   320.0265023,  320.6780023,  323.9355023,  324.5870023,   325.2385023,  325.8900023,  326.541502,   327.193002,   330.4505024,  331.1020024,  331.7535024,  332.4050024,   333.0565024,  333.7080024,  336.9655024,  337.6170024,   338.2685024,  338.9200024,  339.5715024,  340.223002,   761.6617542,  763.3707542,  765.0797542,  766.7887542,   768.4977542,  770.206754,   778.7517543,  780.4607543,   782.1697543,  783.8787543,  785.5877543,  787.2967543,   795.8417544,  797.5507544,  799.2597544,  800.9687544,   802.6777544,  804.3867544,  812.9317545,  814.6407545,   816.3497545,  
818.0587545,  819.7677545,  821.4767545,   830.0217546,  831.7307546,  833.4397546,  835.1487546,   836.8577546,  838.5667546,  847.1117547,  848.8207547,   850.5297547,  852.2387547,  853.9477547,  855.6567547,   1218.9329915, 1221.6994915, 1224.4659915, 1227.232491,   1229.9989914, 1232.7654914, 1246.5979913, 1249.3644913,   1252.1309913, 1254.8974913, 1257.6639913, 1260.430491,   1274.2629912, 1277.029491,  1279.7959911, 1282.5624911,   1285.3289911, 1288.0954911, 1301.9279910, 1304.6944910,   1307.4609910, 1310.22749,   1312.9939909, 1315.7604909,   1329.5929908, 1332.3594908, 1335.1259908, 1337.8924908,   1340.6589908, 1343.4254908, 1357.2579907, 1360.0244907,   1362.7909906, 1365.5574906, 1368.3239906, 1371.0904906,   1676.2042479, 1680.0282479, 1683.8522479, 1687.6762479,   1691.5002479, 1695.3242479, 1714.4442479, 1718.2682479,   1722.0922479, 1725.9162479, 1729.7402479, 1733.5642479,   1752.6842479, 1756.5082479, 1760.3322479, 1764.1562479,   1767.9802479, 1771.8042479, 1790.9242479, 1794.7482479,   1798.5722479, 1802.3962479, 1806.2202479, 1810.044247,   1829.1642478, 1832.9882478, 1836.8122478, 1840.6362478,   1844.4602478, 1848.2842478, 1867.4042478, 1871.2282478,   1875.0522478, 1878.8762478, 1882.7002478, 1886.5242478,   2133.4755029, 2138.3570029, 2143.2385029, 2148.1200029,   2153.0015029, 2157.8830029, 2182.2905028, 2187.1720028,   2192.0535028, 2196.9350028, 2201.8165028, 2206.6980028,   2231.1055028, 2235.9870028, 2240.8685028, 2245.7500028,   2250.6315028, 2255.5130028, 2279.9205027, 2284.8020027,   2289.6835027, 2294.5650027, 2299.4465027, 2304.3280027,   2328.7355027, 2333.6170027, 2338.4985027, 2343.3800027,   2348.2615027, 2353.1430027, 2377.5505026, 2382.4320026,   2387.3135026, 2392.1950026, 2397.0765026, 2401.9580026,   2590.7467330, 2596.6857330, 2602.6247329, 2608.5637329,   2614.5027329, 2620.441732,  2650.1367327, 2656.0757327,   2662.0147326, 2667.9537326, 2673.8927326, 2679.8317325,   2709.5267324, 2715.465732,  2721.4047323, 
2727.3437323,   2733.282732,  2739.2217322, 2768.9167321, 2774.8557320,   2780.7947320, 2786.7337320, 2792.6727319, 2798.6117319,   2828.306731,  2834.2457317, 2840.1847317, 2846.1237317,   2852.0627316, 2858.0017316, 2887.6967314, 2893.6357314,   2899.5747314, 2905.5137313, 2911.4527313, 2917.3917313,   3048.0179587, 3055.0144586, 3062.0109585, 3069.0074584,   3076.0039584, 3083.0004583, 3117.9829579, 3124.9794578,   3131.9759578, 3138.9724577, 3145.9689576, 3152.9654575,   3187.947957,  3194.9444571, 3201.9409570, 3208.9374569,   3215.933956,  3222.9304568, 3257.9129564, 3264.9094563,   3271.9059562, 3278.9024562, 3285.8989561, 3292.8954560,   3327.8779556, 3334.874455,  3341.8709555, 3348.8674554,   3355.8639553, 3362.860455,  3397.8429549, 3404.8394548,   3411.8359547, 3418.8324546, 3425.8289546, 3432.8254545,   3505.28927,   3513.3432780, 3521.3972781, 3529.4512782,   3537.5052782, 3545.5592783, 3585.8292787, 3593.8832788,   3601.9372788, 3609.9912789, 3618.0452790, 3626.099279,   3666.3692794, 3674.4232795, 3682.4772796, 3690.5312796,   3698.5852797, 3706.6392798, 3746.9092801, 3754.9632802,   3763.0172803, 3771.0712804, 3779.1252804, 3787.1792805,   3827.4492809, 3835.50328,   3843.5572810, 3851.6112811,   3859.6652812, 3867.7192812, 3907.9892816, 3916.0432817,   3924.097281,  3932.1512818, 3940.2052819, 3948.2592820,   3962.5605113, 3971.6720113, 3980.783511,  3989.8950114,   3999.0065114, 4008.1180114, 4053.6755115, 4062.7870115,   4071.8985115, 4081.0100115, 4090.1215115, 4099.2330115,   4144.7905116, 4153.9020116, 4163.0135116, 4172.1250116,   4181.236511,  4190.3480117, 4235.9055117, 4245.0170117,   4254.128511,  4263.2400118, 4272.3515118, 4281.4630118,   4327.0205119, 4336.1320119, 4345.2435119, 4354.3550119,   4363.4665119, 4372.5780119, 4418.1355120, 4427.2470120,   4436.3585120, 4445.4700120, 4454.581512,  4463.6930121,   4419.8317743, 4430.0007744, 4440.1697744, 4450.338774,   4460.5077745, 4470.6767745, 4521.521774,  4531.6907748,   4541.8597748, 
4552.0287749, 4562.1977749, 4572.3667750,   4623.2117752, 4633.3807752, 4643.5497753, 4653.7187753,   4663.8877754, 4674.0567754, 4724.9017756, 4735.0707757,   4745.2397757, 4755.4087757, 4765.5777758, 4775.7467758,   4826.591776,  4836.7607761, 4846.9297761, 4857.0987762,   4867.2677762, 4877.4367763, 4928.2817765, 4938.4507765,   4948.6197766, 4958.7887766, 4968.957776,  4979.12677675};
    // Raw shape-info descriptor handed to the NDArray(buffer, shapeInfo) constructor.
    // Presumably: rank 4, shape {2, 10, 6, 6}, strides {360, 36, 6, 1}, a data-type code
    // (8192 for float, 16384 otherwise), then ews and order — TODO confirm against the shapeInfo layout.
    Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
    NDArray expFF(_expBFF, _expSFF);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 3, 10, 10});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {5, 3, 5, 5});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {10, 15, 1, 1});

    // Deterministic contents: linspace(1) fill, axes of both kernels permuted in place,
    // then every array divided by 100.
    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    input.applyScalar(scalar::Divide, 100.0);
    weightsD.applyScalar(scalar::Divide, 100.0);
    weightsP.applyScalar(scalar::Divide, 100.0);

    nd4j::ops::sconv2d op;

    auto resultFF = op.execute({&input, &weightsD, &weightsP}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}, {});

    // Fail with a readable message if the op itself reported an error, instead of
    // dereferencing a missing output below.
    ASSERT_EQ(ND4J_STATUS_OK, resultFF->status());

    auto z = resultFF->at(0);
    //z->printShapeInfo("FF shape");


    ASSERT_TRUE(z->isSameShape(&expFF));

    //expFF.printBuffer("e");
    //z->printBuffer("z");
    // Reference values were divided by 100 and carry limited precision, hence the 1e-3 tolerance.
    ASSERT_TRUE(z->equalsTo(&expFF, 1e-3));

    delete resultFF;
}

// Runs sconv2d with 1x1 kernels and a bias twice on identical inputs: once writing into a
// caller-provided output array and once letting the op allocate its own result. Both runs
// must succeed and must produce the same array.
TYPED_TEST(TypedConvolutionTests1, sconv2d_3) {
    auto input = NDArrayFactory::create<TypeParam>('c', {3, 3, 8, 8});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {1, 3, 1, 1});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {2, 3, 1, 1});
    auto bias = NDArrayFactory::create<TypeParam>('c', {2});
    auto output = NDArrayFactory::create<TypeParam>('c', {3, 2, 8, 8});
    output.assign(0.0);

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    bias.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    nd4j::ops::sconv2d op;

    // First run: result is written into the pre-allocated `output`.
    Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {},  {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
    ASSERT_EQ(ND4J_STATUS_OK, status);

    // Second run: same inputs and arguments, result returned in a ResultSet.
    auto result = op.execute({&input, &weightsD, &weightsP, &bias}, {},  {1, 1, 1, 1, 0, 0, 1, 1, 0}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    //printf("\n");
    //output.printBuffer("output");
    //z->printBuffer("z");

    // No independent reference values exist for this case, so the two execution
    // paths are checked against each other.
    ASSERT_TRUE(output.isSameShape(z));
    ASSERT_TRUE(output.equalsTo(z));

    delete result;
}

// Executes deconv2d through a graph Context (input + weights, no bias) and compares the
// output stored in the variable space against a precomputed reference array.
TYPED_TEST(TypedConvolutionTests1, deconv2D_FF_NoBias_1) {
    // Raw shape-info descriptor for the expected array. Presumably: rank 4, shape {2, 3, 8, 8},
    // strides {192, 64, 8, 1}, a data-type code (8192 for float, 16384 otherwise), ews, order
    // — TODO confirm against the shapeInfo layout.
    Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    // Flat reference output values.
    TypeParam _expB[] = {6276.0,   12831.0,   19668.0,   26790.0,   27012.0,   20703.0,   14100.0,    7200.0,    13719.0,   28023.0,   42918.0,   58410.0,   58902.0,   45105.0,   30693.0,   15660.0,    22389.0,   45696.0,   69930.0,   95100.0,   95910.0,   73386.0,   49899.0,   25440.0,    32346.0,   65970.0,  100884.0,  137100.0,  138276.0,  105726.0,   71838.0,   36600.0,    33726.0,   68790.0,  105204.0,  142980.0,  144156.0,  110226.0,   74898.0,   38160.0,    27555.0,   56154.0,   85806.0,  116520.0,  117474.0,   89748.0,   60933.0,   31020.0,    19917.0,   40557.0,   61926.0,   84030.0,   84714.0,   64671.0,   43875.0,   22320.0,    10752.0,   21879.0,   33384.0,   45270.0,   45636.0,   34815.0,   23604.0,   12000.0,    7551.0,   15456.0,   23718.0,   32340.0,   32562.0,   24978.0,   17025.0,    8700.0,    16569.0,   33873.0,   51918.0,   70710.0,   71202.0,   54555.0,   37143.0,   18960.0,    27114.0,   55371.0,   84780.0,  115350.0,  116160.0,   88911.0,   60474.0,   30840.0,    39246.0,   80070.0,  122484.0,  166500.0,  167676.0,  128226.0,   87138.0,   44400.0,    40626.0,   82890.0,  126804.0,  172380.0,  173556.0,  132726.0,   90198.0,   45960.0,    33180.0,   67629.0,  103356.0,  140370.0,  141324.0,  107973.0,   73308.0,   37320.0,    23967.0,   48807.0,   74526.0,  101130.0,  101814.0,   77721.0,   52725.0,   26820.0,    12927.0,   26304.0,   40134.0,   54420.0,   54786.0,   41790.0,   28329.0,   14400.0,    8826.0,   18081.0,   27768.0,   37890.0,   38112.0,   29253.0,   19950.0,   10200.0,    19419.0,   39723.0,   60918.0,   83010.0,   83502.0,   64005.0,   43593.0,   22260.0,    31839.0,   65046.0,   99630.0,  135600.0,  136410.0,  104436.0,   71049.0,   36240.0,    46146.0,   94170.0,  144084.0,  195900.0,  197076.0,  150726.0,  102438.0,   52200.0,    47526.0,   96990.0,  148404.0,  201780.0,  202956.0,  155226.0,  105498.0,   53760.0,    38805.0,   79104.0,  120906.0,  164220.0,  165174.0,  126198.0,   85683.0,   43620.0,    28017.0,   57057.0, 
  87126.0,  118230.0,  118914.0,   90771.0,   61575.0,   31320.0,    15102.0,   30729.0,   46884.0,   63570.0,   63936.0,   48765.0,   33054.0,   16800.0,    17220.0,   34863.0,   52932.0,   71430.0,   72228.0,   54831.0,   36996.0,   18720.0,    36327.0,   73527.0,  111606.0,  150570.0,  152214.0,  115521.0,   77925.0,   39420.0,    57381.0,  116112.0,  176202.0,  237660.0,  240198.0,  182250.0,  122907.0,   62160.0,    80442.0,  162738.0,  246900.0,  332940.0,  336420.0,  255198.0,  172062.0,   87000.0,    84702.0,  171318.0,  259860.0,  350340.0,  353820.0,  268338.0,  180882.0,   91440.0,    66867.0,  135210.0,  205038.0,  276360.0,  279042.0,  211572.0,  142581.0,   72060.0,    46845.0,   94701.0,  143574.0,  193470.0,  195306.0,  148047.0,   99747.0,   50400.0,    24576.0,   49671.0,   75288.0,  101430.0,  102372.0,   77583.0,   52260.0,   26400.0,    22095.0,   44688.0,   67782.0,   91380.0,   92178.0,   69906.0,   47121.0,   23820.0,    46377.0,   93777.0,  142206.0,  191670.0,  193314.0,  146571.0,   98775.0,   49920.0,    72906.0,  147387.0,  223452.0,  301110.0,  303648.0,  230175.0,  155082.0,   78360.0,    101742.0,  205638.0,  311700.0,  419940.0,  423420.0,  320898.0,  216162.0,  109200.0,    106002.0,  214218.0,  324660.0,  437340.0,  440820.0,  334038.0,  224982.0,  113640.0,    83292.0,  168285.0,  254988.0,  343410.0,  346092.0,  262197.0,  176556.0,   89160.0,    58095.0,  117351.0,  177774.0,  239370.0,  241206.0,  182697.0,  122997.0,   62100.0,    30351.0,   61296.0,   92838.0,  124980.0,  125922.0,   95358.0,   64185.0,   32400.0,    26970.0,   54513.0,   82632.0,  111330.0,  112128.0,   84981.0,   57246.0,   28920.0,    56427.0,  114027.0,  172806.0,  232770.0,  234414.0,  177621.0,  119625.0,   60420.0,    88431.0,  178662.0,  270702.0,  364560.0,  367098.0,  278100.0,  187257.0,   94560.0,    123042.0,  248538.0,  376500.0,  506940.0,  510420.0,  386598.0,  260262.0,  131400.0,    127302.0,  257118.0,  389460.0,  524340.0,  527820.0,  
399738.0,  269082.0,  135840.0,    99717.0,  201360.0,  304938.0,  410460.0,  413142.0,  312822.0,  210531.0,  106260.0,    69345.0,  140001.0,  211974.0,  285270.0,  287106.0,  217347.0,  146247.0,   73800.0,    36126.0,   72921.0,  110388.0,  148530.0,  149472.0,  113133.0,   76110.0,   38400.0,};
    NDArray exp(_expB, _expS);

    // Heap-allocated arrays registered with the variable space below and not deleted here;
    // presumably the VariableSpace takes ownership — TODO confirm.
    auto input = NDArrayFactory::create_<TypeParam>('c', {2, 3, 4, 4});
    auto weights = NDArrayFactory::create_<TypeParam>('c', {3, 3, 5, 5});

    input->linspace(1);
    weights->linspace(1);
    weights->permutei({2,3,1,0});

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, weights);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1, -2});

    auto iArgs = block->getIArguments();

    // kernel size (matches the 5x5 trailing dimensions of `weights`)
    iArgs->push_back(5);
    iArgs->push_back(5);

    // presumably strides — TODO confirm against the deconv2d argument order
    iArgs->push_back(1);
    iArgs->push_back(1);

    // presumably padding — TODO confirm against the deconv2d argument order
    iArgs->push_back(0);
    iArgs->push_back(0);

    // dilation
    iArgs->push_back(1);
    iArgs->push_back(1);

    // NOT same mode
    iArgs->push_back(0);

    // tenth argument; presumably the data-format flag — TODO confirm
    iArgs->push_back(0);

    nd4j::ops::deconv2d op;

    Nd4jStatus status = op.execute(block);

    ASSERT_EQ(ND4J_STATUS_OK, status);

    // The op result is read back from the variable space under node id 1,
    // the id the Context was created with.
    auto output = variableSpace->getVariable(1)->getNDArray();

    ASSERT_TRUE(exp.isSameShape(output));

    //    exp.printBuffer("Expctd buffer");
    //output->printBuffer("Result buffer");
    ASSERT_TRUE(exp.equalsTo(output));

    // Release the Context before the VariableSpace it points to, matching the
    // teardown order used by the other Context-based tests in this file.
    delete block;
    delete variableSpace;
}

// Runs conv2d_bp with input, weights, bias and the incoming gradient, and checks the three
// returned arrays (read as epsilon, gradW, gradB) against precomputed reference values.
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) {
    // Expected weights gradient; built from a raw buffer + shape-info descriptor and then
    // permuted the same way as `weights` so both share one axis order.
    TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0};
    Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expWGrad(_expWGradB, _expWGradS);
    expWGrad.permutei({2,3,1,0});

    // Expected bias gradient.
    TypeParam _expBGradB[] = {784.0, 1296.0};
    Nd4jLong _expBGradS[] = {2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};

    NDArray expBGrad(_expBGradB, _expBGradS);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto bias = NDArrayFactory::create<TypeParam>('c', {2, 1});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});


    // Expected input gradient; reuses the shape info of `input`.
    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, input.getShapeInfo());

    input.linspace(1);
    weights.linspace(1);
    epsilonNext.linspace(1);
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;

    auto results = op.execute({&input, &weights, &bias, &epsilonNext}, {},  {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});

    // Check the op status first so a failed execution is reported as such, then make sure
    // all three outputs exist before indexing into the result set.
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    ASSERT_EQ(3, results->size());

    auto epsilon = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expWGrad.isSameShape(gradW));

    //expWGrad.printBuffer("Expctd buffer");
    //  gradW->printBuffer("Result buffer");
    ASSERT_TRUE(expWGrad.equalsTo(gradW));


    ASSERT_TRUE(input.isSameShape(epsilon));

    //  expEps.printBuffer("Expctd buffer");
    //epsilon->printBuffer("Result buffer");
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    ASSERT_TRUE(expBGrad.isSameShape(gradB));

    ASSERT_TRUE(expBGrad.equalsTo(gradB));

    delete results;
}


// Runs conv2d_bp with input, weights and the incoming gradient only (no bias), and checks
// the two returned arrays (read as epsilon and gradW) against precomputed reference values.
TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) {
    // Expected weights gradient; built from a raw buffer + shape-info descriptor and then
    // permuted the same way as `weights` so both share one axis order.
    TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0};
    Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expWGrad(_expWGradB, _expWGradS);
    expWGrad.permutei({2,3,1,0});

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});


    // Expected input gradient; reuses the shape info of `input`.
    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, input.getShapeInfo());

    input.linspace(1);
    weights.linspace(1);
    epsilonNext.linspace(1);
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;

    auto results = op.execute({&input, &weights, &epsilonNext}, {},  {3, 3, 1, 1, 0, 0, 1, 1, 1}, {});

    // Check the op status first so a failed execution is reported as such, then make sure
    // both outputs exist before indexing into the result set.
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    ASSERT_EQ(2, results->size());

    auto epsilon = results->at(0);
    auto gradW = results->at(1);

    ASSERT_TRUE(expWGrad.isSameShape(gradW));

    //expWGrad.printBuffer("Expctd buffer");
    //  gradW->printBuffer("Result buffer");
    ASSERT_TRUE(expWGrad.equalsTo(gradW));


    ASSERT_TRUE(input.isSameShape(epsilon));

    //  expEps.printBuffer("Expctd buffer");
    //epsilon->printBuffer("Result buffer");
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    delete results;
}

TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) {
    TypeParam _expBFF[] = {10025.0f,   10350.0f,   10675.0f,   11000.0f,   11325.0f,   11650.0f,   13275.0f,   13600.0f,   13925.0f,   14250.0f,   14575.0f,   14900.0f,   16525.0f,   16850.0f,   17175.0f,   17500.0f,   17825.0f,   18150.0f,   19775.0f,   20100.0f,   20425.0f,   20750.0f,   21075.0f,   21400.0f,   23025.0f,   23350.0f,   23675.0f,   24000.0f,   24325.0f,   24650.0f,   26275.0f,   26600.0f,   26925.0f,   27250.0f,   27575.0f,   27900.0f,   53150.0f,   55350.0f,   57550.0f,   59750.0f,   61950.0f,   64150.0f,   75150.0f,   77350.0f,   79550.0f,   81750.0f,   83950.0f,   86150.0f,   97150.0f,   99350.0f,  101550.0f,  103750.0f,  105950.0f,  108150.0f,   119150.0f,  121350.0f,  123550.0f,  125750.0f,  127950.0f,  130150.0f,   141150.0f,  143350.0f,  145550.0f,  147750.0f,  149950.0f,  152150.0f,   163150.0f,  165350.0f,  167550.0f,  169750.0f,  171950.0f,  174150.0f,   119400.0f,  120350.0f,  121300.0f,  122250.0f,  123200.0f,  124150.0f,   128900.0f,  129850.0f,  130800.0f,  131750.0f,  132700.0f,  133650.0f,   138400.0f,  139350.0f,  140300.0f,  141250.0f,  142200.0f,  143150.0f,   147900.0f,  148850.0f,  149800.0f,  150750.0f,  151700.0f,  152650.0f,   157400.0f,  158350.0f,  159300.0f,  160250.0f,  161200.0f,  162150.0f,   166900.0f,  167850.0f,  168800.0f,  169750.0f,  170700.0f,  171650.0f,   350025.0f,  352850.0f,  355675.0f,  358500.0f,  361325.0f,  364150.0f,   378275.0f,  381100.0f,  383925.0f,  386750.0f,  389575.0f,  392400.0f,   406525.0f,  409350.0f,  412175.0f,  415000.0f,  417825.0f,  420650.0f,   434775.0f,  437600.0f,  440425.0f,  443250.0f,  446075.0f,  448900.0f,   463025.0f,  465850.0f,  468675.0f,  471500.0f,  474325.0f,  477150.0f,   491275.0f,  494100.0f,  496925.0f,  499750.0f,  502575.0f,  505400.0f,   353775.0f,  355350.0f,  356925.0f,  358500.0f,  360075.0f,  361650.0f,   369525.0f,  371100.0f,  372675.0f,  374250.0f,  375825.0f,  377400.0f,   385275.0f,  386850.0f,  388425.0f,  390000.0f,  391575.0f,  393150.0f,   401025.0f, 
 402600.0f,  404175.0f,  405750.0f,  407325.0f,  408900.0f,   416775.0f,  418350.0f,  419925.0f,  421500.0f,  423075.0f,  424650.0f,   432525.0f,  434100.0f,  435675.0f,  437250.0f,  438825.0f,  440400.0f,   771900.0f,  775350.0f,  778800.0f,  782250.0f,  785700.0f,  789150.0f,   806400.0f,  809850.0f,  813300.0f,  816750.0f,  820200.0f,  823650.0f,   840900.0f,  844350.0f,  847800.0f,  851250.0f,  854700.0f,  858150.0f,   875400.0f,  878850.0f,  882300.0f,  885750.0f,  889200.0f,  892650.0f,   909900.0f,  913350.0f,  916800.0f,  920250.0f,  923700.0f,  927150.0f,   944400.0f,  947850.0f,  951300.0f,  954750.0f,  958200.0f,  961650.0f,   107525.0f,  107850.0f,  108175.0f,  108500.0f,  108825.0f,  109150.0f,   110775.0f,  111100.0f,  111425.0f,  111750.0f,  112075.0f,  112400.0f,   114025.0f,  114350.0f,  114675.0f,  115000.0f,  115325.0f,  115650.0f,   117275.0f,  117600.0f,  117925.0f,  118250.0f,  118575.0f,  118900.0f,   120525.0f,  120850.0f,  121175.0f,  121500.0f,  121825.0f,  122150.0f,   123775.0f,  124100.0f,  124425.0f,  124750.0f,  125075.0f,  125400.0f,   713150.0f,  715350.0f,  717550.0f,  719750.0f,  721950.0f,  724150.0f,   735150.0f,  737350.0f,  739550.0f,  741750.0f,  743950.0f,  746150.0f,   757150.0f,  759350.0f,  761550.0f,  763750.0f,  765950.0f,  768150.0f,   779150.0f,  781350.0f,  783550.0f,  785750.0f,  787950.0f,  790150.0f,   801150.0f,  803350.0f,  805550.0f,  807750.0f,  809950.0f,  812150.0f,   823150.0f,  825350.0f,  827550.0f,  829750.0f,  831950.0f,  834150.0f,   404400.0f,  405350.0f,  406300.0f,  407250.0f,  408200.0f,  409150.0f,   413900.0f,  414850.0f,  415800.0f,  416750.0f,  417700.0f,  418650.0f,   423400.0f,  424350.0f,  425300.0f,  426250.0f,  427200.0f,  428150.0f,   432900.0f,  433850.0f,  434800.0f,  435750.0f,  436700.0f,  437650.0f,   442400.0f,  443350.0f,  444300.0f,  445250.0f,  446200.0f,  447150.0f,   451900.0f,  452850.0f,  453800.0f,  454750.0f,  455700.0f,  456650.0f,   1197525.0f, 1200350.0f, 1203175.0f, 
1206000.0f, 1208825.0f, 1211650.0f,   1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f,   1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f,   1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f,   1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f,   1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f,   826275.0f,  827850.0f,  829425.0f,  831000.0f,  832575.0f,  834150.0f,   842025.0f,  843600.0f,  845175.0f,  846750.0f,  848325.0f,  849900.0f,   857775.0f,  859350.0f,  860925.0f,  862500.0f,  864075.0f,  865650.0f,   873525.0f,  875100.0f,  876675.0f,  878250.0f,  879825.0f,  881400.0f,   889275.0f,  890850.0f,  892425.0f,  894000.0f,  895575.0f,  897150.0f,   905025.0f,  906600.0f,  908175.0f,  909750.0f,  911325.0f,  912900.0f,   1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f,   1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f,   1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f,   1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f,   1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f,   1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.};
    Nd4jLong _expSFF[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
    NDArray expFF(_expBFF, _expSFF);
    TypeParam _exp2BFF[] = {827.4900282f,   832.2350283f,   836.9800284f,   841.725028f,   846.4700287f,   851.2150288f,   874.9400293f,   879.6850294f,   884.4300295f,   889.1750296f,   893.9200297f,   898.665029f,   922.3900304f,   927.1350305f,   931.8800306f,   936.6250307f,   941.3700308f,   946.1150309f,   969.8400315f,   974.5850316f,   979.3300317f,   984.0750318f,   988.8200319f,   993.5650320f,   1017.2900326f,  1022.0350327f,  1026.7800328f,  1031.5250329f,   1036.2700330f,  1041.0150331f,  1064.7400337f,  1069.4850338f,   1074.2300339f,  1078.9750340f,  1083.7200341f,  1088.4650342f,   1822.4550553f,  1833.995055f,   1845.5350558f,  1857.075056f,   1868.6150563f,  1880.1550566f,  1937.8550578f,  1949.3950581f,   1960.9350583f,  1972.4750586f,  1984.015058f,   1995.5550591f,   2053.2550604f,  2064.7950606f,  2076.3350609f,  2087.8750611f,   2099.4150614f,  2110.955061f,   2168.6550629f,  2180.1950632f,   2191.7350634f,  2203.2750637f,  2214.8150639f,  2226.3550642f,   2284.0550655f,  2295.5950657f,  2307.1350660f,  2318.6750662f,   2330.2150665f,  2341.7550667f,  2399.4550680f,  2410.9950683f,   2422.5350685f,  2434.0750688f,  2445.6150690f,  2457.1550693f,   2817.419968f,   2835.7549686f,  2854.0899683f,  2872.4249680f,   2890.7599677f,  2909.0949674f,  3000.7699660f,  3019.104965f,   3037.4399655f,  3055.7749652f,  3074.1099649f,  3092.4449646f,   3184.1199632f,  3202.4549629f,  3220.789962f,   3239.1249624f,   3257.4599621f,  3275.7949618f,  3367.4699604f,  3385.8049601f,   3404.1399598f,  3422.474959f,   3440.8099593f,  3459.1449590f,   3550.8199576f,  3569.1549573f,  3587.4899570f,  3605.8249567f,   3624.1599565f,  3642.4949562f,  3734.1699548f,  3752.5049545f,   3770.8399542f,  3789.1749539f,  3807.5099536f,  3825.8449534f,   3812.385098f,   3837.5150988f,  3862.6450994f,  3887.7751000f,   3912.9051006f,  3938.0351012f,  4063.6851041f,  4088.8151047f,   4113.9451053f,  4139.0751059f,  4164.2051065f,  4189.3351071f,   4314.9851100f,  4340.1151106f,  
4365.2451112f,  4390.3751118f,   4415.5051124f,  4440.6351130f,  4566.2851159f,  4591.4151165f,   4616.5451171f,  4641.6751177f,  4666.805118f,   4691.9351188f,   4817.5851218f,  4842.7151224f,  4867.8451230f,  4892.975123f,   4918.1051241f,  4943.2351247f,  5068.8851277f,  5094.0151283f,   5119.1451288f,  5144.2751294f,  5169.4051300f,  5194.5351306f,   4807.3499803f,  4839.2749801f,  4871.1999799f,  4903.1249797f,   4935.0499795f,  4966.9749793f,  5126.5999784f,  5158.5249782f,   5190.4499780f,  5222.3749778f,  5254.2999777f,  5286.2249775f,   5445.8499765f,  5477.774976f,   5509.6999762f,  5541.6249760f,   5573.5499758f,  5605.4749756f,  5765.0999747f,  5797.0249745f,   5828.9499743f,  5860.8749741f,  5892.7999739f,  5924.724973f,   6084.3499728f,  6116.2749726f,  6148.1999724f,  6180.1249723f,   6212.0499721f,  6243.9749719f,  6403.59997f,    6435.5249708f,   6467.4499706f,  6499.3749704f,  6531.2999702f,  6563.2249700f,   5802.3150007f,  5841.0350006f,  5879.7550005f,  5918.4750004f,   5957.195000f,   5995.9150003f,  6189.5149999f,  6228.2349998f,   6266.9549997f,  6305.6749996f,  6344.3949995f,  6383.114999f,   6576.7149990f,  6615.4349990f,  6654.1549989f,  6692.8749988f,   6731.5949987f,  6770.3149986f,  6963.9149982f,  7002.6349981f,   7041.3549981f,  7080.0749980f,  7118.7949979f,  7157.5149978f,   7351.1149974f,  7389.8349973f,  7428.5549972f,  7467.2749972f,   7505.9949971f,  7544.7149970f,  7738.3149966f,  7777.0349965f,   7815.7549964f,  7854.4749963f,  7893.1949963f,  7931.9149962f,   6797.2799488f,  6842.794948f,   6888.3099489f,  6933.8249490f,   6979.3399491f,  7024.8549492f,  7252.4299497f,  7297.9449498f,   7343.4599499f,  7388.9749500f,  7434.489950f,   7480.0049501f,   7707.5799506f,  7753.0949507f,  7798.6099508f,  7844.1249509f,   7889.6399510f,  7935.1549511f,  8162.7299515f,  8208.2449516f,   8253.7599517f,  8299.2749518f,  8344.7899519f,  8390.3049520f,   8617.8799525f,  8663.394952f,   8708.9099526f,  8754.4249527f,   8799.9399528f,  
8845.4549529f,  9073.0299534f,  9118.5449535f,   9164.0599536f,  9209.5749537f,  9255.089953f,   9300.604953f,   7792.2451647f,  7844.5551655f,  7896.8651663f,  7949.1751671f,   8001.4851679f,  8053.7951686f,  8315.3451725f,  8367.6551733f,   8419.9651741f,  8472.2751749f,  8524.585175f,   8576.8951764f,   8838.4451803f,  8890.7551811f,  8943.0651819f,  8995.3751827f,   9047.6851834f,  9099.9951842f,  9361.5451881f,  9413.8551889f,   9466.1651897f,  9518.475190f,   9570.7851912f,  9623.0951920f,   9884.6451959f,  9936.9551967f,  9989.2651975f, 10041.5751982f,   10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f,   10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f,   8787.210074f,   8846.3150748f,  8905.4200750f,  8964.5250752f,   9023.6300755f,  9082.7350757f,  9378.2600768f,  9437.3650770f,   9496.4700773f,  9555.5750775f,  9614.6800777f,  9673.7850779f,   9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f,   10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f,   10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f,   11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f,   11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f,   11860.6700863f, 11919.7750865f, 11978.880086f,  12037.9850870f,   9782.1750935f,  9848.0750935f,  9913.9750934f,  9979.8750934f,   10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f,   10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f,   11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f,   11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f,   11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f,   12418.175092f,  12484.0750920f, 12549.9750920f, 12615.8750919f,   12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f,   13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f,   2250.990060f,   2255.7350610f,  2260.4800611f,  2265.2250612f,   2269.9700613f,  2274.7150614f,  2298.4400619f,  
2303.185062f,   2307.9300622f,  2312.6750623f,  2317.4200624f,  2322.1650625f,   2345.8900630f,  2350.6350631f,  2355.380063f,   2360.1250634f,   2364.8700635f,  2369.6150636f,  2393.3400641f,  2398.0850642f,   2402.8300643f,  2407.5750644f,  2412.320064f,   2417.0650647f,   2440.7900652f,  2445.5350653f,  2450.2800654f,  2455.0250655f,   2459.7700656f,  2464.515065f,   2488.2400663f,  2492.9850664f,   2497.7300665f,  2502.4750666f,  2507.2200667f,  2511.9650668f,   5284.4551315f,  5295.9951318f,  5307.535132f,   5319.0751323f,   5330.6151326f,  5342.1551328f,  5399.8551341f,  5411.3951343f,   5422.9351346f,  5434.475134f,   5446.0151351f,  5457.5551354f,   5515.2551366f,  5526.7951369f,  5538.3351371f,  5549.8751374f,   5561.4151376f,  5572.9551379f,  5630.6551392f,  5642.1951394f,   5653.7351397f,  5665.2751399f,  5676.8151402f,  5688.3551404f,   5746.0551417f,  5757.5951420f,  5769.1351422f,  5780.6751425f,   5792.2151427f,  5803.7551430f,  5861.455144f,   5872.9951445f,   5884.5351448f,  5896.0751450f,  5907.6151453f,  5919.1551455f,   8317.919884f,   8336.2548841f,  8354.5898838f,  8372.9248835f,   8391.2598832f,  8409.59488f,    8501.2698815f,  8519.6048813f,   8537.9398810f,  8556.2748807f,  8574.6098804f,  8592.9448801f,   8684.6198787f,  8702.9548784f,  8721.2898782f,  8739.6248779f,   8757.9598776f,  8776.2948773f,  8867.9698759f,  8886.3048756f,   8904.6398753f,  8922.9748751f,  8941.3098748f,  8959.6448745f,   9051.3198731f,  9069.6548728f,  9087.9898725f,  9106.3248722f,   9124.6598720f,  9142.9948717f,  9234.6698703f,  9253.0048700f,   9271.3398697f,  9289.6748694f,  9308.0098691f,  9326.3448689f,   11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f,   11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f,   11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f,   11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f,   11954.505288f,  11979.6352894f, 12105.2852924f, 12130.4152930f,   12155.545293f,  
12180.6752941f, 12205.8052947f, 12230.9352953f,   12356.5852983f, 12381.715298f,  12406.8452994f, 12431.9753000f,   12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f,   12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f,   14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f,   14512.549923f,  14544.4749235f, 14704.0999225f, 14736.024922f,   14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f,   15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f,   15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f,   15406.4499184f, 15438.374918f,  15470.2999181f, 15502.2249179f,   15661.84991f,   15693.7749168f, 15725.6999166f, 15757.6249164f,   15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f,   16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f,   17418.314976f,  17457.0349761f, 17495.7549760f, 17534.4749759f,   17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f,   17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f,   18192.7149745f, 18231.4349744f, 18270.154974f,  18308.8749743f,   18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f,   18657.3549735f, 18696.074973f,  18734.7949734f, 18773.5149733f,   18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f,   19121.994972f,  19160.7149725f, 19354.3149721f, 19393.0349720f,   19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f,   20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f,   20633.8399769f, 20679.3549770f, 20906.929977f,  20952.4449775f,   20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f,   21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f,   21544.139978f,  21589.6549788f, 21817.2299793f, 21862.7449794f,   21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f,   22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f,   22454.4399806f, 22499.9549807f, 22727.529981f,  22773.044981f,   22818.5599813f, 22864.0749814f, 
22909.5899815f, 22955.1049816f,   23485.2453985f, 23537.555399f,  23589.8654000f, 23642.1754008f,   23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f,   24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f,   24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f,   24740.6854172f, 24792.99541f,   25054.545421f,  25106.8554226f,   25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f,   25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f,   25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f,   26205.3654390f, 26257.6754398f, 26309.985440f,  26362.2954413f,   26518.7101423f, 26577.8151425f, 26636.920142f,  26696.0251430f,   26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f,   27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f,   27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f,   27937.2301477f, 27996.33514f,   28291.8601491f, 28350.9651493f,   28410.0701495f, 28469.175149f,  28528.2801500f, 28587.3851502f,   28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f,   29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f,   29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f,   29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f,   29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f,   30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f,   30870.175081f,  30936.0750818f, 31001.9750818f, 31067.8750817f,   31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f,   31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f,   32188.1750811f, 32254.0750811f, 32319.975081f,  32385.8750810f,   32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f,   32978.9750807f, 33044.875080f,  33110.7750806f, 33176.67508062};
    Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,};
    NDArray exp2FF(_exp2BFF, _exp2SFF);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 3, 10, 10});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {2, 3, 5, 5});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {10, 6, 1, 1});


    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    weightsP.applyScalar(scalar::Divide, 10000.0);

    nd4j::ops::sconv2d op;
    auto resultFF = op.execute({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});

    auto z = resultFF->at(0);

    ASSERT_TRUE(z->isSameShape(&expFF));
    ASSERT_TRUE(z->equalsTo(&expFF, 1));


    nd4j::ops::conv2d op2d;
    // weightsP.printShapeInfo();
    auto result2D = op2d.execute({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {});

    auto z2d = result2D->at(0);

    ASSERT_TRUE(z2d->isSameShape(&exp2FF));
    ASSERT_TRUE(z2d->equalsTo(&exp2FF));

    delete resultFF;
    delete result2D;
}



// Compares the im2col and col2im custom ops with the results produced by
// NDArray::applyTransform using transform::Im2col / transform::Col2Im.
// Setup: 'c'-ordered {2, 3, 28, 28} input, 5x5 kernel, stride 1, dilation 1, isSameMode = true.
// The op status is asserted before any output is read from a result set.
TEST_F(ConvolutionTests1, Test_im2col_col2im_1) {
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;

    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;
    x.syncToDevice();
    //ASSERT_TRUE(x.isActualOnDeviceSide());
    ASSERT_TRUE(x.isActualOnHostSide());

    // oY/oX are passed uninitialized and used afterwards, so the helper is expected to fill them in
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);

    // NOTE(review): presumably recomputes pY/pX for same mode - confirm against calcPadding2D
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    // reference im2col result, produced through the transform path
    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, kY, kX, oY, oX});

    ExtraArguments args({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args);

    // im2col result produced by the custom op
    nd4j::ops::im2col op;
    auto result2col = op.execute({&x}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2col->status());

    auto im2col1 = result2col->at(0);

    ASSERT_TRUE(im2col1->isSameShape(&im2col0));
    ASSERT_TRUE(im2col1->equalsTo(&im2col0));


    // reference col2im result, produced through the transform path
    ExtraArguments args2({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2);

    // col2im result produced by the custom op, fed with the op's own im2col output
    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());

    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2col;
    delete result2im;
}


// Same comparison as Test_im2col_col2im_1 (custom im2col / col2im ops versus
// transform::Im2col / transform::Col2Im), but the reference im2col array is
// allocated as {2, channels, oY, oX, kY, kX} and permuted in place, so its
// logical shape is {2, channels, kY, kX, oY, oX}.
// The op status is asserted before any output is read from a result set.
TEST_F(ConvolutionTests1, Test_im2col_col2im_2) {
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;

    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;

    // oY/oX are passed uninitialized and used afterwards, so the helper is expected to fill them in
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);

    // NOTE(review): presumably recomputes pY/pX for same mode - confirm against calcPadding2D
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    // reference im2col target: permuted array written through the transform path
    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col0.permutei({0, 1, 4, 5, 2, 3});

    ExtraArguments args2col({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args2col);

    // im2col result produced by the custom op
    nd4j::ops::im2col op;
    auto result2col = op.execute({&x}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2col->status());

    auto im2col1 = result2col->at(0);

    ASSERT_TRUE(im2col1->isSameShape(&im2col0));
    ASSERT_TRUE(im2col1->equalsTo(&im2col0));


    // reference col2im result, produced through the transform path
    ExtraArguments args2im({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2im);

    // col2im result produced by the custom op, fed with the op's own im2col output
    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());

    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2col;
    delete result2im;
}

// Variant of Test_im2col_col2im_2 in which the im2col op writes into a
// caller-provided, permuted output array (im2col1) instead of returning a
// result set. The col2im op still returns a result set; its status is asserted
// before the output is read.
TEST_F(ConvolutionTests1, Test_im2col_col2im_3) {
    int kY = 5;
    int kX = 5;
    int sY = 1;
    int sX = 1;
    int pY = 0;
    int pX = 0;
    int dY = 1;
    int dX = 1;
    int inY = 28;
    int inX = 28;
    int channels = 3;

    bool isSameMode = true;

    auto x = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    x.linspace(1);

    int oY, oX;

    // oY/oX are passed uninitialized and used afterwards, so the helper is expected to fill them in
    nd4j::ops::ConvolutionUtils::calcOutSizePool2D(oY, oX, kY, kX, sY, sX, pY, pX, dY, dX, inY, inX, isSameMode);

    // NOTE(review): presumably recomputes pY/pX for same mode - confirm against calcPadding2D
    if (isSameMode)
        nd4j::ops::ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);

    // reference target (transform path) and op target share the same permuted layout
    auto im2col0 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col0.permutei({0, 1, 4, 5, 2, 3});

    auto im2col1 = NDArrayFactory::create<double>('c', {2, channels, oY, oX, kY, kX});
    im2col1.permutei({0, 1, 4, 5, 2, 3});

    ExtraArguments args2col({(double) kY, (double) kX, (double) sY, (double) sX, (double) pY, (double) pX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0, (double)0.0, (double) 0.});
    x.applyTransform(transform::Im2col, &im2col0, &args2col);

    // this execute overload takes the output array explicitly and returns a status
    nd4j::ops::im2col op;
    auto status = op.execute({&x}, {&im2col1}, {}, {kY, kX, sY, sX, pY, pX, dY, dX, isSameMode ? 1 : 0}, {});
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(im2col1.isSameShape(&im2col0));
    ASSERT_TRUE(im2col1.equalsTo(&im2col0));


    // reference col2im result, produced through the transform path
    ExtraArguments args2im({ (double) sY, (double) sX, (double) pY, (double) pX, (double) inY, (double) inX, (double) dY, (double) dX, isSameMode ? (double) 1 : (double) 0});
    auto col2im0 = NDArrayFactory::create<double>('c', {2, channels, inY, inX});
    im2col0.applyTransform(transform::Col2Im, &col2im0, &args2im);

    // col2im result produced by the custom op
    nd4j::ops::col2im op2im;
    auto result2im = op2im.execute({&im2col1}, {}, {sY, sX, pY, pX, inY, inX, dY, dX, isSameMode ? 1 : 0});
    ASSERT_EQ(Status::OK(), result2im->status());

    auto col2im1 = result2im->at(0);

    ASSERT_TRUE(col2im1->isSameShape(&col2im0));
    ASSERT_TRUE(col2im1->equalsTo(&col2im0));

    delete result2im;
}


// Runs deconv2d_bp on linspace-filled input {3, 3, 4, 4}, weights {3, 2, 1, 1}
// (permuted with {2,3,1,0}), bias {1, 2} and epsilon {3, 2, 4, 4}, then compares
// the three outputs (epsilon, weight gradient, bias gradient) with hard-coded values.
TEST_F(ConvolutionTests1, TestDeconv_bp_1) {

    // expected epsilon; the DataBuffer is built directly over this stack array
    double epsilonValues[] = {  35.f,   38.f,   41.f,   44.f,   47.f,   50.f,   53.f,   56.f,   59.f,   62.f,   65.f,    68.f,   71.f,   74.f,   77.f,   80.f,   71.f,   78.f,   85.f,   92.f,   99.f,  106.f,    113.f,  120.f,  127.f,  134.f,  141.f,  148.f,  155.f,  162.f,  169.f,  176.f,  107.f,    118.f,  129.f,  140.f,  151.f,  162.f,  173.f,  184.f,  195.f,  206.f,  217.f,  228.f,    239.f,  250.f,  261.f,  272.f,  131.f,  134.f,  137.f,  140.f,  143.f,  146.f,  149.f,    152.f,  155.f,  158.f,  161.f,  164.f,  167.f,  170.f,  173.f,  176.f,  295.f,  302.f,    309.f,  316.f,  323.f,  330.f,  337.f,  344.f,  351.f,  358.f,  365.f,  372.f,  379.f,    386.f,  393.f,  400.f,  459.f,  470.f,  481.f,  492.f,  503.f,  514.f,  525.f,  536.f,    547.f,  558.f,  569.f,  580.f,  591.f,  602.f,  613.f,  624.f,  227.f,  230.f,  233.f,    236.f,  239.f,  242.f,  245.f,  248.f,  251.f,  254.f,  257.f,  260.f,  263.f,  266.f,    269.f,  272.f,  519.f,  526.f,  533.f,  540.f,  547.f,  554.f,  561.f,  568.f,  575.f,    582.f,  589.f,  596.f,  603.f,  610.f,  617.f,  624.f,  811.f,  822.f,  833.f,  844.f,    855.f,  866.f,  877.f,  888.f,  899.f,  910.f,  921.f,  932.f,  943.f,  954.f,  965.f,    976.f};
    std::shared_ptr<DataBuffer> epsilonBuffer = std::make_shared<DataBuffer>(epsilonValues, sizeof(epsilonValues), nd4j::DataType::DOUBLE, false);
    NDArray expEpsilon(epsilonBuffer, 'c', {3, 3, 4, 4});

    // expected weight gradient, permuted the same way as the weights fed to the op
    double gradWValues[] = { 160008.f,  203400.f,  191112.f,  246792.f,  222216.f,  290184.f};
    std::shared_ptr<DataBuffer> gradWBuffer = std::make_shared<DataBuffer>(gradWValues, sizeof(gradWValues), nd4j::DataType::DOUBLE, false);
    NDArray expGradW(gradWBuffer, 'c', {3, 2, 1, 1});
    expGradW.permutei({2,3,1,0});

    // expected bias gradient
    double gradBValues[] = {1944.f,  2712.f};
    std::shared_ptr<DataBuffer> gradBBuffer = std::make_shared<DataBuffer>(gradBValues, sizeof(gradBValues), nd4j::DataType::DOUBLE, false);
    NDArray expGradB(gradBBuffer, 'c', {1, 2});

    // op inputs
    auto input = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
    auto bias = NDArrayFactory::create<double>('c', {1, 2});
    auto weights = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
    auto epsilon = NDArrayFactory::create<double>('c', {3, 2, 4, 4});

    input.linspace(1);
    weights.linspace(1);
    bias.linspace(1);
    epsilon.linspace(1);
    weights.permutei({2,3,1,0});

    nd4j::ops::deconv2d_bp deconvBp;
    auto outputs = deconvBp.execute({&input, &weights, &bias, &epsilon}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // output 0: epsilon, output 1: weight gradient, output 2: bias gradient
    auto actualEpsilon = outputs->at(0);
    ASSERT_TRUE(expEpsilon.isSameShape(actualEpsilon));
    ASSERT_TRUE(expEpsilon.equalsTo(actualEpsilon));

    auto actualGradW = outputs->at(1);
    ASSERT_TRUE(expGradW.isSameShape(actualGradW));
    ASSERT_TRUE(expGradW.equalsTo(actualGradW));

    auto actualGradB = outputs->at(2);
    ASSERT_TRUE(expGradB.isSameShape(actualGradB));
    ASSERT_TRUE(expGradB.equalsTo(actualGradB));

    delete outputs;
}

// Placeholder test: the whole body is commented out, so this test currently
// passes without executing deconv2d_bp. The disabled code uses the templated
// form `deconv2d_bp<double>` and `auto name(...)` declarations, unlike
// TestDeconv_bp_1 in this file, so it would need porting before being re-enabled.
TEST_F(ConvolutionTests1, TestDeconv_bp_2) {
    /*
     Input shape:
    [3, 3, 14, 14]
    Output shape:
    [3, 2, 15, 15]
    Weights shape:
    [3, 2, 2, 2]
    Bias shape:
    [1, 2]
    weight shape:
    [3, 2, 2, 2]
    weight grad shape:
    [3, 2, 2, 2]
    bias grad shape:
    [2]
    input epsilon shape:
    [3, 2, 15, 15]
    output epsilon shape:
    [3, 3, 14, 14]
     */
    /*
    auto input('c', {3, 3, 14, 14});
    auto bias('c', {2});
    auto weights('c',{3, 2, 2, 2});
    auto epsilon('c', {3, 2, 15, 15});


    input.linspace(1);
    weights.linspace(1);
    bias.linspace(1);
    epsilon.linspace(1);

    nd4j::ops::deconv2d_bp<double> op;

    auto result = op.execute({&input, &weights, &bias, &epsilon}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());


    delete result;*/
}

// Forward deconv2d on linspace-filled input {3, 3, 4, 4}, weights {3, 2, 1, 1}
// (permuted with {2,3,1,0}) and bias {2}; the single output is compared with
// a hard-coded {3, 2, 4, 4} array.
TEST_F(ConvolutionTests1, TestDeconv_ff_2) {

    NDArray expected('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299.,  308., 317., 326., 335., 344., 353., 270., 282., 294., 306.,  318., 330., 342., 354., 366., 378., 390., 402., 414., 426.,  438., 450., 650., 659., 668., 677., 686., 695., 704., 713.,  722., 731., 740., 749., 758., 767., 776., 785., 846., 858.,  870., 882., 894., 906., 918., 930., 942., 954., 966., 978.,  990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127.,  1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217.,  1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530.,  1542., 1554., 1566., 1578., 1590., 1602.});

    // op inputs, all filled with 1, 2, 3, ...
    auto inputArr = NDArrayFactory::create<double>('c', {3, 3, 4, 4});
    auto kernel = NDArrayFactory::create<double>('c',{3, 2, 1, 1});
    auto biasArr = NDArrayFactory::create<double>('c', {2});

    inputArr.linspace(1);
    kernel.linspace(1);
    biasArr.linspace(1);
    kernel.permutei({2,3,1,0});

    nd4j::ops::deconv2d deconv;
    auto outputs = deconv.execute({&inputArr, &kernel, &biasArr}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto actual = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outputs;
}

// Runs conv1d forward and then conv1d_bp on the same input/weights/bias and
// compares the forward output, the input epsilon, the weight gradient and the
// bias gradient with hard-coded values. The epsilon fed to the backward pass is
// a copy of the forward output refilled with 1, 2, 3, ...
TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
    // op inputs
    auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
    auto bias = NDArrayFactory::create<TypeParam>('c', {3});

    // expected results
    auto expForward = NDArrayFactory::create<TypeParam>('c', {2, 3, 5}, {59.0, 69.0, 79.0, 89.0, 99.0, 132.0, 158.0, 184.0, 210.0, 236.0, 205.0, 247.0, 289.0, 331.0, 373.0, 179.0, 189.0, 199.0, 209.0, 219.0, 444.0, 470.0, 496.0, 522.0, 548.0, 709.0, 751.0, 793.0, 835.0, 877.0});
    auto expEpsilon = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}, {130.0, 293.0, 326.0, 359.0, 392.0, 220.0, 166.0, 371.0, 416.0, 461.0, 506.0, 280.0, 355.0, 788.0, 821.0, 854.0, 887.0, 490.0, 481.0, 1046.0, 1091.0, 1136.0, 1181.0, 640.0});
    auto expGradW = NDArrayFactory::create<TypeParam>('c', {3, 2, 2}, {1415.0, 1520.0, 2045.0, 2150.0, 1865.0, 2020.0, 2795.0, 2950.0, 2315.0, 2520.0, 3545.0, 3750.0});
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {3}, {105.0, 155.0, 205.0});

    expGradW.permutei({2,1,0});
    input.linspace(1);
    bias.linspace(1);

    // forward pass
    nd4j::ops::conv1d forwardOp;
    auto forwardOut = forwardOp.execute({&input, &weights, &bias}, {}, {2, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, forwardOut->status());

    auto activations = forwardOut->at(0);
    ASSERT_TRUE(expForward.isSameShape(activations));
    ASSERT_TRUE(expForward.equalsTo(activations));

    // backward pass
    nd4j::ops::conv1d_bp backwardOp;

    auto nextEpsilon = activations->dup();
    nextEpsilon->linspace(1);

    auto backwardOut = backwardOp.execute({&input, &weights, &bias, nextEpsilon}, {}, {2, 1, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, backwardOut->status());

    auto actualEpsilon = backwardOut->at(0);
    auto actualGradW = backwardOut->at(1);
    auto actualGradB = backwardOut->at(2);

    // all shapes first, then all values
    ASSERT_TRUE(expEpsilon.isSameShape(actualEpsilon));
    ASSERT_TRUE(expGradW.isSameShape(actualGradW));
    ASSERT_TRUE(expGradB.isSameShape(actualGradB));

    ASSERT_TRUE(expEpsilon.equalsTo(actualEpsilon));
    ASSERT_TRUE(expGradW.equalsTo(actualGradW));
    ASSERT_TRUE(expGradB.equalsTo(actualGradB));

    delete forwardOut;
    delete backwardOut;
    delete nextEpsilon;
}


// Smoke test: runs conv1d without a bias input and only asserts the returned
// status; the output values are fetched but not compared with anything.
TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
    auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
    input.linspace(1);

    nd4j::ops::conv1d conv;
    auto outputs = conv.execute({&input, &weights}, {}, {2, 1, 0, 1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    // fetched only; no assertion is made on it
    auto firstOutput = outputs->at(0);

    delete outputs;
}

// dilation2d on a linspace-filled {2, 6, 6, 3} input with {3, 2, 3} weights;
// the first integer argument is 1 here and the expected output is {2, 3, 3, 3}.
TEST_F(ConvolutionTests1, Test_Dilation2D_1) {
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 3, 3}, {77,   79,   81,   83,   85,   87,   80,   82,   84,  113,  115,  117, 119,  121,  123,  116,  118,  120,  107,  109,  111,  113,  115,  117, 110,  112,  114,  185,  187,  189,  191,  193,  195,  188,  190,  192, 221,  223,  225,  227,  229,  231,  224,  226,  228,  215,  217,  219, 221,  223,  225,  218,  220,  222,});

    auto image = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
    auto kernel = NDArrayFactory::create<double>('c', {3, 2, 3});
    image.linspace(1);
    kernel.linspace(1);

    nd4j::ops::dilation2d dilation;
    auto outputs = dilation.execute({&image, &kernel}, {}, {1, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), outputs->status());

    auto actual = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outputs;
}

// dilation2d on the same linspace-filled inputs as Test_Dilation2D_1, but with
// the first integer argument set to 0; the expected output is {2, 1, 2, 3}.
TEST_F(ConvolutionTests1, Test_Dilation2D_2) {
    auto expected = NDArrayFactory::create<double>('c', {2, 1, 2, 3}, {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213});

    auto image = NDArrayFactory::create<double>('c', {2, 6, 6, 3});
    auto kernel = NDArrayFactory::create<double>('c', {3, 2, 3});
    image.linspace(1);
    kernel.linspace(1);

    nd4j::ops::dilation2d dilation;
    auto outputs = dilation.execute({&image, &kernel}, {}, {0, 1,2,2,1, 1,2,2,1});
    ASSERT_EQ(Status::OK(), outputs->status());

    auto actual = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outputs;
}

//////////////////////////////////////////////////////////////////////
// conv2d_bp with paddingMode = 1 (SAME) and dataFormat = 1 (NHWC).
// Input is filled with 2, weights and gradO with linspace(0.1, 0.1) / linspace(0.01, 0.01);
// the input gradient and weight gradient are compared with hard-coded values.
// The bias gradient (third output) is not verified in this test.
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226f,  0.343f,  0.46f,  0.577f, 1.172f,  1.46f,  1.748f,  2.036f, 1.892f,  2.288f,  2.684f,  3.08f, 1.284f,  1.581f,  1.878f,  2.175f, 4.458f,  5.133f,  5.808f,  6.483f, 6.186f,  7.023f,  7.86f,  8.697f,
                                                     3.39f,  3.93f,  4.47f,  5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f,  5.707f,  6.148f,  6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f,
                                                     3.25f,  4.015f,  4.78f,  5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f,
                                                    11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f,  9.48f,  9.72f, 9.24f,  9.48f,  9.72f, 9.24f,  9.48f,  9.72f, 9.24f,  9.48f,  9.72f,
                                                    17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,
                                                    11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f,  7.32f,  7.56f, 7.08f,  7.32f,  7.56f, 7.08f,  7.32f,  7.56f, 7.08f,  7.32f,  7.56f});
    // auto expGradB('c', {oC},{});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // check the status before touching any output of the result set
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// conv2d_bp with paddingMode = 0 (VALID) and dataFormat = 1 (NHWC).
// Input is filled with 2, weights and gradO with linspace(0.1, 0.1) / linspace(0.01, 0.01);
// the input gradient and weight gradient are compared with hard-coded values.
// The bias gradient (third output) is not verified in this test.
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f,
                                                     0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f,
                                                     0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f,
                                                     3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
                                                    1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
                                                    1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f});
    // auto expGradB('c', {oC},{});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // check the status before touching any output of the result set
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// conv2d_bp with paddingMode = 0 (VALID) and dataFormat = 0 (NCHW).
// Weights and the expected weight gradient are created as {oC, iC, kH, kW} and
// permuted with {2,3,1,0} before the op runs. Input gradient, weight gradient
// and bias gradient are all compared with hard-coded values.
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3,  kH=3,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567, 1.224,0.66 ,1.314, 2.82 ,1.512,1.386, 2.976,1.596,0.801, 1.71 ,0.912,0.657, 1.422,0.768,1.53 , 3.288,1.764,1.602, 3.444,1.848,0.927, 1.98 ,1.056,
                                                     0.747, 1.62 ,0.876,1.746, 3.756,2.016,1.818, 3.912,2.1  ,1.053, 2.25 ,1.2  ,0.837, 1.818,0.984,1.962, 4.224,2.268,2.034, 4.38 ,2.352,1.179, 2.52 ,1.344,
                                                     1.467, 3.06 ,1.596,3.186, 6.636,3.456,3.402, 7.08 ,3.684,1.845, 3.834,1.992,1.773, 3.69 ,1.92 ,3.834, 7.968,4.14 ,4.05 , 8.412,4.368,2.187, 4.536,2.352,
                                                     2.079, 4.32 ,2.244,4.482, 9.3  ,4.824,4.698, 9.744,5.052,2.529, 5.238,2.712,2.385, 4.95 ,2.568,5.13 ,10.632,5.508,5.346,11.076,5.736,2.871, 5.94 ,3.072});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,
                                                    1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,
                                                    2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,
                                                    2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00});
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68, 1., 1.32});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);
    // weights and their expected gradient get the same permutation
    weights.permutei({2,3,1,0});
    expGradW.permutei({2,3,1,0});

    nd4j::ops::conv2d_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // check the status before touching any output of the result set
    ASSERT_EQ(Status::OK(), results->status());

    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, conv2d_bp_4) {
    // Smoke test: runs conv2d_bp into caller-allocated FLOAT32 gradient buffers and
    // checks only the returned status; no gradient values are compared.

    int bS=1;
    int iH=7, iW=1;
    int iC=2, oC=3;
    int kH=2, kW=1;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oH=7, oW=1;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    const auto f32 = nd4j::DataType::FLOAT32;

    // arrays handed to the op as inputs
    NDArray input('c', {bS, iC, iH, iW}, f32);
    NDArray weights('c', {kH, kW, iC, oC}, f32);
    NDArray bias('c', {oC}, {1,2,3}, f32);
    NDArray gradO('c', {bS, oC, oH, oW}, f32);

    // arrays handed to the op as outputs
    NDArray gradI('c', {bS, iC, iH, iW}, f32);
    NDArray gradW('c', {kH, kW, iC, oC}, f32);
    NDArray gradB('c', {oC}, f32);

    gradO.linspace(0.01, 0.01);
    weights.linspace(0.1, 0.1);
    input = 2.;

    nd4j::ops::conv2d_bp op;
    const Nd4jStatus status = op.execute({&input, &weights, &bias, &gradO},
                                         {&gradI, &gradW, &gradB},
                                         {},
                                         {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat},
                                         {});

    ASSERT_EQ(Status::OK(), status);
}

////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
    // conv3dnew_bp with SAME padding and NDHWC layout: constant input (2), weights
    // 0.1, 0.2, ... and gradO 0.01, 0.02, ... are fed in, and the gradients w.r.t.
    // input, weights and bias are compared with precomputed values.

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int       oD=3,oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226,   0.343,   0.46 ,   0.577,  1.172,   1.46 ,   1.748,   2.036,  1.892,   2.288,   2.684,   3.08 ,  1.284,   1.581,   1.878,   2.175,  4.458,   5.133,   5.808,   6.483,  6.186,   7.023,   7.86 ,   8.697,  3.39 ,   3.93 ,   4.47 ,   5.01 ,  9.642,  10.803,  11.964,  13.125, 11.37 ,  12.693,  14.016,  15.339,
                                                        5.266,   5.707,   6.148,   6.589, 12.98 ,  13.916,  14.852,  15.788, 14.564,  15.608,  16.652,  17.696,  6.284,   7.166,   8.048,   8.93 , 17.896,  19.768,  21.64 ,  23.512, 21.928,  24.016,  26.104,  28.192, 18.12 ,  19.686,  21.252,  22.818, 45.852,  49.146,  52.44 ,  55.734, 53.196,  56.814,  60.432,  64.05 ,
                                                       28.164,  30.216,  32.268,  34.32 , 67.884,  72.15 ,  76.416,  80.682, 75.228,  79.818,  84.408,  88.998, 29.324,  30.854,  32.384,  33.914, 67.432,  70.6  ,  73.768,  76.936, 73.192,  76.576,  79.96 ,  83.344, 27.884,  30.062,  32.24 ,  34.418, 66.28 ,  70.744,  75.208,  79.672, 70.312,  74.992,  79.672,  84.352,
                                                       58.296,  61.806,  65.316,  68.826,133.98 , 141.162, 148.344, 155.526,141.324, 148.83 , 156.336, 163.842, 68.34 ,  72.336,  76.332,  80.328,156.012, 164.166, 172.32 , 180.474,163.356, 171.834, 180.312, 188.79 , 61.292,  64.118,  66.944,  69.77 ,136.552, 142.312, 148.072, 153.832,142.312, 148.288, 154.264, 160.24 ,
                                                        9.298,  11.359,  13.42 ,  15.481, 27.092,  31.268,  35.444,  39.62 , 27.812,  32.096,  36.38 ,  40.664, 26.556,  29.769,  32.982,  36.195, 66.666,  73.173,  79.68 ,  86.187, 68.394,  75.063,  81.732,  88.401, 28.662,  32.118,  35.574,  39.03 , 71.85 ,  78.843,  85.836,  92.829, 73.578,  80.733,  87.888,  95.043,
                                                       29.89 ,  32.275,  34.66 ,  37.045, 70.004,  74.828,  79.652,  84.476, 71.588,  76.52 ,  81.452,  86.384, 71.084,  75.854,  80.624,  85.394,163.048, 172.696, 182.344, 191.992,167.08 , 176.944, 186.808, 196.672,138.648, 146.046, 153.444, 160.842,310.236, 325.194, 340.152, 355.11 ,317.58 , 332.862, 348.144, 363.426,
                                                      148.692, 156.576, 164.46 , 172.344,332.268, 348.198, 364.128, 380.058,339.612, 355.866, 372.12 , 388.374,125.228, 130.646, 136.064, 141.482,274.792, 285.736, 296.68 , 307.624,280.552, 291.712, 302.872, 314.032, 92.684,  98.75 , 104.816, 110.882,211.432, 223.672, 235.912, 248.152,215.464, 227.92 , 240.376, 252.832,
                                                      178.824, 188.166, 197.508, 206.85 ,398.364, 417.21 , 436.056, 454.902,405.708, 424.878, 444.048, 463.218,188.868, 198.696, 208.524, 218.352,420.396, 440.214, 460.032, 479.85 ,427.74 , 447.882, 468.024, 488.166,157.196, 163.91 , 170.624, 177.338,343.912, 357.448, 370.984, 384.52 ,349.672, 363.424, 377.176, 390.928});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12, 79.56,  80.28,  81.  , 79.56,  80.28,  81.  , 79.56,  80.28,  81.  , 79.56,  80.28,  81.  ,
                                                        154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,
                                                        111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 , 73.08,  73.8 ,  74.52, 73.08,  73.8 ,  74.52, 73.08,  73.8 ,  74.52, 73.08,  73.8 ,  74.52,
                                                         67.68,  68.4 ,  69.12, 67.68,  68.4 ,  69.12, 67.68,  68.4 ,  69.12, 67.68,  68.4 ,  69.12, 44.4 ,  44.88,  45.36, 44.4 ,  44.88,  45.36, 44.4 ,  44.88,  45.36, 44.4 ,  44.88,  45.36,
                                                         85.92,  86.88,  87.84, 85.92,  86.88,  87.84, 85.92,  86.88,  87.84, 85.92,  86.88,  87.84, 56.32,  56.96,  57.6 , 56.32,  56.96,  57.6 , 56.32,  56.96,  57.6 , 56.32,  56.96,  57.6 ,
                                                         61.2 ,  61.92,  62.64, 61.2 ,  61.92,  62.64, 61.2 ,  61.92,  62.64, 61.2 ,  61.92,  62.64, 40.08,  40.56,  41.04, 40.08,  40.56,  41.04, 40.08,  40.56,  41.04, 40.08,  40.56,  41.04});

    // Bias gradient = gradO summed over every axis except the channel one. gradO holds
    // 0.01*(k+1) in c-order with the channel axis last, so with 72 positions per channel
    // gradB[c] = 0.01 * sum_{j=0..71} (3*j + c + 1) = 0.01 * (7668 + 72*(c+1)).
    // Cross-check: the largest expGradW entries (154.8, 156.24, 157.68) are input (2) times these.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{77.4, 78.12, 78.84});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});
    auto* gradI = results->at(0);
    auto* gradW = results->at(1);
    auto* gradB = results->at(2);

    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}


////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
    // conv3dnew_bp with VALID padding and NDHWC layout: constant input (2), weights
    // 0.1, 0.2, ... and gradO 0.01, 0.02, ... are fed in, and the gradients w.r.t.
    // input, weights and bias are compared with precomputed values.

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int       oD=2,oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014, 0.032, 0.05 , 0.068, 0.118, 0.181, 0.244, 0.307, 0.212, 0.257, 0.302, 0.347, 0.208, 0.298, 0.388, 0.478, 1.028, 1.262, 1.496, 1.73 , 1.036, 1.18 , 1.324, 1.468, 0.928, 1.018, 1.108, 1.198, 2.9  , 3.134, 3.368, 3.602, 2.188, 2.332, 2.476, 2.62 ,
                                                         1.202, 1.274, 1.346, 1.418, 3.142, 3.313, 3.484, 3.655, 2.048, 2.147, 2.246, 2.345, 0.532, 0.676, 0.82 , 0.964, 2.324, 2.666, 3.008, 3.35 , 2.008, 2.206, 2.404, 2.602, 3.584, 3.98 , 4.376, 4.772,10.552,11.452,12.352,13.252, 7.4  , 7.904, 8.408, 8.912,
                                                         6.752, 7.148, 7.544, 7.94 ,17.752,18.652,19.552,20.452,11.432,11.936,12.44 ,12.944, 5.932, 6.184, 6.436, 6.688,14.42 ,14.978,15.536,16.094, 8.704, 9.01 , 9.316, 9.622, 3.11 , 3.236, 3.362, 3.488, 7.39 , 7.669, 7.948, 8.227, 4.388, 4.541, 4.694, 4.847,
                                                         8.56 , 8.866, 9.172, 9.478,19.892,20.558,21.224,21.89 ,11.548,11.908,12.268,12.628,11.008,11.314,11.62 ,11.926,25.22 ,25.886,26.552,27.218,14.428,14.788,15.148,15.508, 7.322, 7.502, 7.682, 7.862,16.462,16.849,17.236,17.623, 9.248, 9.455, 9.662, 9.869,
                                                         0.158, 0.392, 0.626, 0.86 , 1.27 , 1.765, 2.26 , 2.755, 1.22 , 1.481, 1.742, 2.003, 2.224, 2.746, 3.268, 3.79 , 6.788, 7.886, 8.984,10.082, 4.78 , 5.356, 5.932, 6.508, 6.4  , 6.922, 7.444, 7.966,15.572,16.67 ,17.768,18.866, 9.388, 9.964,10.54 ,11.116,
                                                         4.802, 5.09 , 5.378, 5.666,11.206,11.809,12.412,13.015, 6.512, 6.827, 7.142, 7.457, 6.004, 6.58 , 7.156, 7.732,14.996,16.202,17.408,18.614, 9.208, 9.838,10.468,11.098,17.984,19.244,20.504,21.764,42.808,45.436,48.064,50.692,25.256,26.624,27.992,29.36 ,
                                                        28.064,29.324,30.584,31.844,63.832,66.46 ,69.088,71.716,36.2  ,37.568,38.936,40.304,18.316,19.   ,19.684,20.368,40.916,42.338,43.76 ,45.182,22.816,23.554,24.292,25.03 , 8.438, 8.78 , 9.122, 9.464,18.91 ,19.621,20.332,21.043,10.58 ,10.949,11.318,11.687,
                                                        20.944,21.682,22.42 ,23.158,46.388,47.918,49.448,50.978,25.66 ,26.452,27.244,28.036,26.848,27.586,28.324,29.062,58.628,60.158,61.688,63.218,31.996,32.788,33.58 ,34.372,16.106,16.502,16.898,17.294,34.894,35.713,36.532,37.351,18.896,19.319,19.742,20.165});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
                                                        7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
                                                        7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,
                                                        7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16});

    // Bias gradient = gradO summed over every axis except the channel one. gradO holds
    // 0.01*(k+1) in c-order with the channel axis last, so with 16 positions per channel
    // gradB[c] = 0.01 * sum_{j=0..15} (3*j + c + 1) = 0.01 * (360 + 16*(c+1)).
    // Cross-check: every expGradW entry (7.52, 7.84, 8.16) is input (2) times these.
    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{3.76, 3.92, 4.08});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv3dnew_bp op;
    auto results = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});
    auto gradI = results->at(0);
    auto gradW = results->at(1);
    auto gradB = results->at(2);

    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));

    delete results;
}


////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
    // conv3dnew_bp with VALID padding and NCDHW layout. Weights and the expected weight
    // gradient are created as {oC, iC, kD, kH, kW} and permuted in place with
    // {2, 3, 4, 1, 0} before the op runs. Gradients w.r.t. input, weights and bias
    // are all compared with precomputed values.

    int bS=2;
    int iD=3, iH=4, iW=3;
    int iC=4, oC=3;
    int kD=2, kH=3, kW=2;
    int sD=1, sH=1, sW=1;
    int pD=0, pH=0, pW=0;
    int dD=1, dH=1, dW=1;
    int oD=2, oH=2, oW=2;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    // op inputs
    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});

    // expected gradients
    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091, 4.356, 2.268, 4.53 , 9.42 , 4.896, 4.65 , 9.672, 5.028, 2.517, 5.226, 2.712, 4.932,10.242, 5.316,10.62 ,22.02 ,11.412,10.908,22.62 ,11.724, 5.868,12.15 , 6.288, 2.913, 6.03 , 3.12 , 6.234,12.888, 6.66 , 6.402,13.236, 6.84 , 3.423, 7.068, 3.648,
                                                        2.415, 5.04 , 2.628, 5.25 ,10.932, 5.688, 5.37 ,11.184, 5.82 , 2.913, 6.054, 3.144, 5.724,11.898, 6.18 ,12.348,25.62 ,13.284,12.636,26.22 ,13.596, 6.804,14.094, 7.296, 3.381, 7.002, 3.624, 7.242,14.976, 7.74 , 7.41 ,15.324, 7.92 , 3.963, 8.184, 4.224,
                                                        2.739, 5.724, 2.988, 5.97 ,12.444, 6.48 , 6.09 ,12.696, 6.612, 3.309, 6.882, 3.576, 6.516,13.554, 7.044,14.076,29.22 ,15.156,14.364,29.82 ,15.468, 7.74 ,16.038, 8.304, 3.849, 7.974, 4.128, 8.25 ,17.064, 8.82 , 8.418,17.412, 9.   , 4.503, 9.3  , 4.8  ,
                                                        3.063, 6.408, 3.348, 6.69 ,13.956, 7.272, 6.81 ,14.208, 7.404, 3.705, 7.71 , 4.008, 7.308,15.21 , 7.908,15.804,32.82 ,17.028,16.092,33.42 ,17.34 , 8.676,17.982, 9.312, 4.317, 8.946, 4.632, 9.258,19.152, 9.9  , 9.426,19.5  ,10.08 , 5.043,10.416, 5.376,
                                                        5.619,11.484, 5.868,11.73 ,23.964,12.24 ,12.138,24.792,12.66 , 6.333,12.93 , 6.6  ,12.42 ,25.362,12.948,25.884,52.836,26.964,26.748,54.588,27.852,13.932,28.422,14.496, 6.873,14.022, 7.152,14.298,29.16 ,14.868,14.754,30.084,15.336, 7.671,15.636, 7.968,
                                                        6.807,13.896, 7.092,14.178,28.932,14.76 ,14.586,29.76 ,15.18 , 7.593,15.486, 7.896,14.94 ,30.474,15.54 ,31.068,63.348,32.292,31.932,65.1  ,33.18 ,16.596,33.822,17.232, 8.205,16.722, 8.52 ,17.034,34.704,17.676,17.49 ,35.628,18.144, 9.075,18.48 , 9.408,
                                                        7.995,16.308, 8.316,16.626,33.9  ,17.28 ,17.034,34.728,17.7  , 8.853,18.042, 9.192,17.46 ,35.586,18.132,36.252,73.86 ,37.62 ,37.116,75.612,38.508,19.26 ,39.222,19.968, 9.537,19.422, 9.888,19.77 ,40.248,20.484,20.226,41.172,20.952,10.479,21.324,10.848,
                                                        9.183,18.72 , 9.54 ,19.074,38.868,19.8  ,19.482,39.696,20.22 ,10.113,20.598,10.488,19.98 ,40.698,20.724,41.436,84.372,42.948,42.3  ,86.124,43.836,21.924,44.622,22.704,10.869,22.122,11.256,22.506,45.792,23.292,22.962,46.716,23.76 ,11.883,24.168,12.288});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28,
                                                        5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28,
                                                        7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84,
                                                        7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84,
                                                        10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4,
                                                        10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4});

    auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64, 3.92, 5.2 });

    // fill the inputs; weights are filled before being permuted
    input = 2.;
    gradO.linspace(0.01, 0.01);
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});
    expGradW.permutei({2, 3, 4, 1, 0});

    nd4j::ops::conv3dnew_bp op;
    ResultSet* resultSet = op.execute({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});
    NDArray* actGradI = resultSet->at(0);
    NDArray* actGradW = resultSet->at(1);
    NDArray* actGradB = resultSet->at(2);

    ASSERT_EQ(Status::OK(), resultSet->status());

    // input gradient
    ASSERT_TRUE(expGradI.isSameShape(actGradI));
    ASSERT_TRUE(expGradI.equalsTo(actGradI));

    // weights gradient
    ASSERT_TRUE(expGradW.isSameShape(actGradW));
    ASSERT_TRUE(expGradW.equalsTo(actGradW));

    // bias gradient
    ASSERT_TRUE(expGradB.isSameShape(actGradB));
    ASSERT_TRUE(expGradB.equalsTo(actGradB));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) {
    // Forward depthwise_conv2d without bias, SAME padding, NHWC layout: constant input (2)
    // and weights 0.1, 0.2, ... are fed in and the output is compared with precomputed values.

    int bS=2;
    int iH=4, iW=3;
    int iC=2, mC=2;
    int kH=3, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oC=iC*mC;
    int oH=4, oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2,  5.6,  6. ,  6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4,  6. ,  6.6,  7.2,
                                                     13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4,  6. ,  6.6,  7.2, 5.6,  6.4,  7.2,  8. , 5.6,  6.4,  7.2,  8. , 2. ,  2.4,  2.8,  3.2,
                                                     12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2,  5.6,  6. ,  6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4,  6. ,  6.6,  7.2,
                                                     13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4,  6. ,  6.6,  7.2, 5.6,  6.4,  7.2,  8. , 5.6,  6.4,  7.2,  8. , 2. ,  2.4,  2.8,  3.2});

    weights.linspace(0.1, 0.1);
    input = 2.;

    nd4j::ops::depthwise_conv2d op;
    ResultSet* resultSet = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actual = resultSet->at(0);

    ASSERT_EQ(Status::OK(), resultSet->status());

    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_2) {
    // Forward depthwise_conv2d without bias, VALID padding, NHWC layout, double precision:
    // constant input (2) and weights 0.1, 0.2, ... are fed in and the output is compared
    // with precomputed values.

    int bS=2;
    int iH=4, iW=3;
    int iC=2, mC=2;
    int kH=3, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oC=iC*mC;
    int oH=2, oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,
                                                     13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8});

    weights.linspace(0.1, 0.1);
    input = 2.;

    nd4j::ops::depthwise_conv2d op;
    ResultSet* resultSet = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actual = resultSet->at(0);

    ASSERT_EQ(Status::OK(), resultSet->status());

    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resultSet;
}


//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_3) {
    // Forward depthwise_conv2d with bias, VALID padding, NCHW layout, double precision.
    // Weights are created as {mC, iC, kH, kW}, filled with 0.1, 0.2, ... and then
    // permuted in place with {2,3,1,0} before being passed to the op.

    int bS=2;
    int iH=4, iW=3;
    int iC=2, mC=2;
    int kH=3, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oC=iC*mC;
    int oH=2, oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto weights  = NDArrayFactory::create<double>('c', {mC, iC, kH, kW});
    auto biases   = NDArrayFactory::create<double>('c', {iC*mC}, {1,2,3,4});

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8});

    input = 2.;
    // fill first, then permute
    weights.linspace(0.1, 0.1);
    weights.permutei({2,3,1,0});

    nd4j::ops::depthwise_conv2d op;
    ResultSet* resultSet = op.execute({&input, &weights, &biases}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actual = resultSet->at(0);

    // actual->printIndexedBuffer();

    ASSERT_EQ(Status::OK(), resultSet->status());

    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_4) {
    // Runs depthwise_conv2d (SAME padding, NHWC, sH=sW=2) into a pre-allocated FLOAT32
    // output that was first filled with a sentinel, then checks that no element in the
    // last third of the output buffer still holds that sentinel.

    int bS=1, iH=111,iW=111,  iC=32,mC=1,  kH=7,kW=7,  sH=2,sW=2,  pH=0,pW=0,  dH=1,dW=1;
    int       oC=iC*mC;
    int       oH=56,oW=56;

    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // sentinel written to the output beforehand; assumed never to be a legitimate
    // result for the inputs used here
    const float unique = -1000000;

    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
    NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
    input.linspace(0.1, 0.0001);
    weights = 0.5;
    output = unique;

    nd4j::ops::depthwise_conv2d op;
    Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat}, {});

    ASSERT_EQ(Status::OK(), status);

    // Start two thirds of the way in, using integer arithmetic so the index never
    // passes through a floating-point value.
    const Nd4jLong length = output.lengthOf();
    for (Nd4jLong i = length * 2 / 3; i < length; ++i)
        ASSERT_NE(unique, output.e<float>(i)) << "output element " << i << " still holds the sentinel";
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_5) {
    // Forward depthwise_conv2d without bias, SAME padding, NHWC layout, double precision:
    // input is filled via linspace(1.), all weights are 1, and the output is compared
    // with precomputed values.

    int bS=1;
    int iH=3, iW=3;
    int iC=2, mC=1;
    int kH=2, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oC=iC*mC;
    int oH=3, oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});

    auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});

    weights = 1.;
    input.linspace(1.);

    nd4j::ops::depthwise_conv2d op;
    ResultSet* resultSet = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actual = resultSet->at(0);
    // actual->printIndexedBuffer();

    ASSERT_EQ(Status::OK(), resultSet->status());

    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_6) {
    // Forward depthwise_conv2d without bias, SAME padding, NHWC layout. Arrays are built
    // with the NDArray constructors directly (DOUBLE for input and weights) rather than
    // through NDArrayFactory; the output is compared with precomputed values.

    int bS=1;
    int iH=3, iW=3;
    int iC=2, mC=1;
    int kH=2, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oC=iC*mC;
    int oH=3, oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::DOUBLE);
    NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::DOUBLE);

    NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.});

    weights = 1.;
    input.linspace(1.);

    nd4j::ops::depthwise_conv2d op;
    auto resultSet = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    auto* actual = resultSet->at(0);
    // actual->printIndexedBuffer();

    ASSERT_EQ(Status::OK(), resultSet->status());

    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test1) {
    // depthwise_conv2d_bp with SAME padding, NHWC layout, double precision: constant
    // input (2), weights 0.1, 0.2, ... and gradO 0.01, 0.02, ... are fed in; the
    // gradients w.r.t. input and weights are compared with precomputed values.

    int bS=2;
    int iH=4, iW=3;
    int iC=2, mC=2;
    int kH=3, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oH=4, oW=3;
    int oC=iC*mC;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // op inputs
    auto input    = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    auto bias     = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
    auto gradO    = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});

    // expected gradients
    auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.07 ,  0.19 , 0.348,  0.652, 0.588,  0.956, 0.387,  0.687, 1.326,  2.022, 1.878,  2.67 , 1.071,  1.515, 2.982,  3.966, 3.534,  4.614, 1.606,  1.982, 3.932,  4.748, 4.428,  5.308,
                                                    1.126,  1.63 , 3.228,  4.3  , 3.468,  4.604, 3.123,  3.999, 7.95 ,  9.798, 8.502, 10.446, 3.807,  4.827, 9.606, 11.742,10.158, 12.39 , 4.198,  4.958, 9.884, 11.468,10.38 , 12.028});

    auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36,  9.6 , 9.84, 10.08});

    gradO.linspace(0.01, 0.01);
    weights.linspace(0.1, 0.1);
    input = 2.;

    nd4j::ops::depthwise_conv2d_bp op;
    ResultSet* resultSet = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actGradI = resultSet->at(0);
    NDArray* actGradW = resultSet->at(1);

    ASSERT_EQ(Status::OK(), resultSet->status());

    // input gradient
    ASSERT_TRUE(expGradI.isSameShape(actGradI));
    ASSERT_TRUE(expGradI.equalsTo(actGradI));

    // weights gradient
    ASSERT_TRUE(expGradW.isSameShape(actGradW));
    ASSERT_TRUE(expGradW.equalsTo(actGradW));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, depthwise_conv2d_bp_test2) {
    // depthwise_conv2d_bp with VALID padding, NHWC layout, double precision: constant
    // input (2), weights 0.1, 0.2, ... and gradO 0.01, 0.02, ... are fed in; the
    // gradients w.r.t. input and weights are compared with precomputed values.

    int bS=2;
    int iH=4, iW=3;
    int iC=2, mC=2;
    int kH=3, kW=2;
    int sH=1, sW=1;
    int pH=0, pW=0;
    int dH=1, dW=1;
    int oH=2, oW=2;
    int oC=iC*mC;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // op inputs
    auto input    = NDArrayFactory::create<double>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    auto bias     = NDArrayFactory::create<double>('c', {oC}, {1,2,3,4});
    auto gradO    = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});

    // expected gradients
    auto expGradI = NDArrayFactory::create<double>('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729,
                                                    0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481});
    auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88});

    gradO.linspace(0.01, 0.01);
    weights.linspace(0.1, 0.1);
    input = 2.;

    nd4j::ops::depthwise_conv2d_bp op;
    ResultSet* resultSet = op.execute({&input, &weights, &bias, &gradO}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});
    NDArray* actGradI = resultSet->at(0);
    NDArray* actGradW = resultSet->at(1);

    ASSERT_EQ(Status::OK(), resultSet->status());

    // input gradient
    ASSERT_TRUE(expGradI.isSameShape(actGradI));
    ASSERT_TRUE(expGradI.equalsTo(actGradI));

    // weights gradient
    ASSERT_TRUE(expGradW.isSameShape(actGradW));
    ASSERT_TRUE(expGradW.equalsTo(actGradW));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
    // Forward conv3dnew without bias, SAME padding, NDHWC layout: constant input (2) and
    // all-ones weights are fed in and the output is compared with precomputed values.

    int bS=2;
    int iD=3, iH=4, iW=3;
    int iC=4, oC=3;
    int kD=2, kH=3, kW=2;
    int sD=1, sH=1, sW=1;
    int pD=0, pH=0, pW=0;
    int dD=1, dH=1, dW=1;
    int paddingMode = 1;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
                                                   64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
                                                   96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
                                                   48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
                                                   64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
                                                   64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,
                                                   96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16.,
                                                   48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.});

    weights = 1.;
    input = 2.;

    nd4j::ops::conv3dnew op;
    ResultSet* resultSet = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});
    NDArray* actual = resultSet->at(0);

    ASSERT_EQ(Status::OK(), resultSet->status());
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete resultSet;
}


//////////////////////////////////////////////////////////////////////
// conv3dnew without bias: NDHWC input, SAME padding.
// Input is filled with 2. and weights with 0.1, 0.2, 0.3, ... (linspace); the result is
// compared against the precomputed values in `expected`.
TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 1;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
                                                   380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
                                                   686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6,
                                                   170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2,
                                                   534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
                                                   380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,
                                                   686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6,
                                                   170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2});
    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// conv3dnew without bias: NDHWC input, VALID padding.
// Input is filled with 2. and weights with 0.1, 0.2, 0.3, ... (linspace); every output
// position is expected to hold the same per-channel triple.
TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3},  {686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
                                                    686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
                                                    686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,
                                                    686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6});
    input = 2.;
    weights.linspace(0.1, 0.1);

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}


//////////////////////////////////////////////////////////////////////
// conv3dnew without bias: NCDHW input, VALID padding, constant input and weights.
// Each output element sums kD*kH*kW*iC = 48 products of 2. * 0.5, i.e. 48.
TYPED_TEST(TypedConvolutionTests1, conv3d_test4) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat =  0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2});
    input = 2.;
    weights = 0.5;
    expected = 48.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

////////////////////////////////////////////////////////////////////
// conv3dnew with a constant bias: NCDHW input, VALID padding.
// Same setup as the bias-free constant case (48 per element) plus a bias of 1., i.e. 49.
TYPED_TEST(TypedConvolutionTests1, conv3d_test5) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2});

    input = 2.;
    weights = 0.5;
    expected = 49.;
    bias = 1.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

////////////////////////////////////////////////////////////////////
// conv3dnew with a per-channel bias {1, 2, 3}: NCDHW input, VALID padding.
// The constant convolution result (48) is shifted by the bias of each output channel,
// giving 49, 50 and 51 for channels 0, 1 and 2.
TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49., 49.,49., 49., 49., 49.,49., 49., 50., 50.,50., 50., 50., 50.,50., 50.,
                                                  51., 51.,51., 51., 51., 51.,51., 51., 49., 49.,49., 49., 49., 49.,49., 49.,
                                                  50., 50.,50., 50., 50., 50.,50., 50., 51., 51.,51., 51., 51., 51.,51., 51.});
    input = 2.;
    weights = 0.5;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

////////////////////////////////////////////////////////////////////
// conv3dnew with bias {1, 2, 3} and permuted weights: NCDHW input, VALID padding.
// The weights are allocated as {oC, iC, kD, kH, kW}, filled with 0.1, 0.2, ... and then
// permuted in place into the {kD, kH, kW, iC, oC} axis order before being passed to the op.
TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. ,
                                                  698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,
                                                  236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. ,
                                                  698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8});
    input = 2.;
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights, &bias}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

////////////////////////////////////////////////////////////////////
// conv3dnew without bias and with permuted weights: NCDHW input, VALID padding.
// The weights are allocated as {oC, iC, kD, kH, kW}, filled with 0.1, 0.2, ... and then
// permuted in place into the {kD, kH, kW, iC, oC} axis order before being passed to the op.
TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {

    int bS=2, iD=3,iH=4,iW=3,  iC=4,oC=3,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. ,
                                                  1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2,
                                                  696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. ,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8});
    input = 2.;
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Shape-only check for conv3dnew with non-unit strides: NDHWC input, SAME padding.
// The array contents are left as created; only the shape of the result is verified.
TYPED_TEST(TypedConvolutionTests1, conv3d_test9) {

    int kD=2,kH=5,kW=5,  sD=5,sH=4,sW=3,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 1;             // 1-SAME,  0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {4, 2, 28, 28, 3});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 3, 4});
    auto expected = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // The status is checked first; the output is only read once the op is known to have succeeded.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));

    delete results;
}

// Shape-only check for conv3dnew (NDHWC input, SAME padding) that additionally runs
// calculateOutputShape() through a manually populated Context and expects exactly one
// output shape to be produced.
TYPED_TEST(TypedConvolutionTests1, conv3d_test10) {
    auto x = NDArrayFactory::create<TypeParam>('c', {4, 2, 28, 28, 3});
    auto w = NDArrayFactory::create<TypeParam>('c', {2, 5, 5, 3, 4});
    auto exp = NDArrayFactory::create<TypeParam>('c', {4, 1, 7, 10, 4});

    nd4j::ops::conv3dnew op;
    auto result = op.execute({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1});
    ASSERT_EQ(Status::OK(), result->status());

    ShapeList shapeList({x.shapeInfo(), w.shapeInfo()});
    Context ctx(1);

    // Integer arguments in op order: kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat.
    // The last one (dataFormat) must be 1: a previous variant pushed 0 there and that caused a failure.
    // NOTE(review): paddingMode is 0 here while the execute() call above passes 1, and the
    // computed shape is not compared against anything - confirm whether this is intended.
    auto* iArgs = ctx.getIArguments();
    for (int arg : {2,5,5,  5,4,3,  0,0,0,  1,1,1,  0,1})
        iArgs->push_back(arg);

    auto shapes = op.calculateOutputShape(&shapeList, ctx);
    ASSERT_EQ(1, shapes->size());

    auto z = result->at(0);
    ASSERT_TRUE(exp.isSameShape(z));

    delete result;

    // The shape list owns the shape buffers it holds; release them before deleting the list itself.
    shapes->destroy();
    delete shapes;
}

//////////////////////////////////////////////////////////////////////
// pointwise_conv2d (1x1 kernel) with bias on NHWC input.
// Input is 2. everywhere, weights are 0.1, 0.2, ... (linspace) and bias is 1., so every
// spatial position yields the same per-channel triple {5.4, 6.2, 7.0}.
TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {

    int bS=2, iH=4,iW=3,  iC=4,oC=3;

    int dataFormat = 1;           // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {1,   1, iC, oC});
    auto bias     = NDArrayFactory::create<TypeParam>('c', {oC});


    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2,
                                                      7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4,
                                                      6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0,
                                                      5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0});
    input = 2.;
    weights.linspace(0.1, 0.1);
    bias = 1.;

    nd4j::ops::pointwise_conv2d op;
    auto results = op.execute({&input, &weights, &bias}, {}, {dataFormat});

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Smoke test for conv3dnew: NCDHW input, SAME padding, kernel as large as the input volume.
// Only the returned status is checked; the output values are not inspected.
TYPED_TEST(TypedConvolutionTests1, conv3d_test11) {

    int bS=1, iD=2,iH=2,iW=2,  iC=1,oC=1,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int paddingMode = 1;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});

    input = 2.;
    weights = 1.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    ASSERT_EQ(Status::OK(), results->status());

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Shape-only check for conv3dnew: NCDHW input, VALID padding, single input/output channel.
// `expected` is used purely as a shape carrier; its contents are never compared.
TYPED_TEST(TypedConvolutionTests1, conv3d_test12) {

    int bS=5, iD=4,iH=14,iW=14,  iC=1,oC=1,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int       oD=3,oH=13,oW=13;
    int paddingMode = 0;             // 1-SAME,  0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});

    input = 2.;
    weights = 1.;

    nd4j::ops::conv3dnew op;
    auto results = op.execute({&input, &weights}, {}, {kD,kH,kW,  sD,sH,sW,  pD,pH,pW,  dD,dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before touching its outputs (same ordering as conv3d_test9/10).
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(output->isSameShape(&expected));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// ConvolutionUtils::vol2col on plain c-ordered FLOAT32 arrays.
// The volume {bS, iC, iD, iH, iW} holds 1, 2, 3, ... (linspace) and the result is compared
// element-wise against the precomputed {bS, iC, kD, kH, kW, oD, oH, oW} columns below.
TEST_F(ConvolutionTests1, vol2col_test1) {

    // NOTE(review): oC is declared here but not referenced anywhere in this test.
    int bS=2, iD=2,iH=3,iW=2,  iC=3,oC=2,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int       oD=2,oH=3,oW=2;

    NDArray volume('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
    NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, nd4j::DataType::FLOAT32);

    // Pre-fill the destination with -1, a value that never occurs in the expected data
    // (presumably so that elements vol2col leaves untouched are detected by the comparison).
    columns = -1.;
    volume.linspace(1);

    NDArray columnsExpected('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., 4., 0., 6.,0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., 10., 0., 12., 0., 0., 0., 5., 6.,
0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17.,18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0.,
0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23.,
24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0.,
34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0.,
0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40.,
41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0.,
0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54.,0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0.,
53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0.,0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69.,
70., 71., 72., 0., 0., 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0.,
0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);

    graph::Context context(1);
    nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
    // columns.printBuffer();

    ASSERT_TRUE(columns.equalsTo(columnsExpected));
}

//////////////////////////////////////////////////////////////////////
// ConvolutionUtils::vol2col on permuted arrays.
// Both the volume and the columns buffer are allocated with a shuffled axis order and then
// permuted in place, so the helper is exercised on arrays whose axes were reordered after
// allocation. The expected values are compared element-wise via equalsTo().
TEST_F(ConvolutionTests1, vol2col_test2) {

    // NOTE(review): oC is declared here but not referenced anywhere in this test.
    int bS=2, iD=2,iH=3,iW=2,  iC=3,oC=2,  kD=2,kH=3,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int       oD=2,oH=3,oW=2;

    // Allocated as {iD, bS, iH, iC, iW}; the permutation picks axes 1,3,0,2,4,
    // i.e. the {bS, iC, iD, iH, iW} order, before the values 1, 2, 3, ... are written.
    auto volume  = NDArrayFactory::create<float>('c', {iD, bS, iH, iC, iW});
    volume.permutei({1, 3, 0, 2, 4});
    volume.linspace(1);

    // Allocated as {kD, iC, kH, oW, kW, bS, oD, oH}; the permutation picks axes 5,1,0,2,4,6,7,3,
    // i.e. the {bS, iC, kD, kH, kW, oD, oH, oW} order used by columnsExpected.
    auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH});
    columns.permutei({5, 1, 0, 2, 4, 6, 7, 3});
    // Pre-fill with -1, a value that never occurs in the expected data.
    columns = -1.;
    auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,
10., 11., 12., 2., 0., 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0.,0., 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8.,
9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0.,
23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36.,
0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33.,
34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40.,
0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47.,
48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., 59., 60., 0., 0.,
0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., 64., 0., 66.,
0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 69., 70., 71., 72., 0., 0.,
0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});

    graph::Context context(1);
    nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
        // columns.printBuffer();

    ASSERT_TRUE(columns.equalsTo(columnsExpected));
}

//////////////////////////////////////////////////////////////////////
// helpers::col2im on a {bS, iC, kH, kW, oH, oW} columns array filled with 1, 2, 3, ...
// The destination image is pre-filled with -2 and, after the call, must equal the
// precomputed {bS, iC, iH, iW} values in `imageExpected`.
TEST_F(ConvolutionTests1, col2im_test1) {

    // Only the 2D parameters are needed here; the depth-related ones were never used.
    int bS=2, iH=2,iW=2,  iC=2,   kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=2,oW=2;

    auto image  = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    image = -2.;

    auto columns = NDArrayFactory::create<float>('c', {bS, iC, kH, kW, oH, oW});
    columns.linspace(1);

    auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1., 7., 12., 34., 17., 39., 44., 98., 33., 71., 76., 162., 49., 103., 108., 226.});

    LaunchContext ctx;
    nd4j::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW);

    ASSERT_TRUE(image.equalsTo(imageExpected));
}


//////////////////////////////////////////////////////////////////////
// upsampling2d with isNCHW = 0: the input is laid out as {bS, iH, iW, iC} and filled with
// 1, 2, 3, ...; the expected output has shape {bS, iH*factorH, iW*factorW, iC} in which
// every channel triple of the input is repeated factorW times along W and factorH times along H.
TEST_F(ConvolutionTests1, upsampling2d_test1) {

    const int bS=3,  iH=2,iW=2,  iC=3;
    const int factorH=2, factorW=3;
    const int isNCHW = 0;                    // data format, default is NCHW

    auto input  = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1.,  2.,  3., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 4.,  5.,  6., 1.,  2.,  3., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 4.,  5.,  6.,
                                         7.,  8.,  9., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12.,10., 11., 12., 7.,  8.,  9., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12.,10., 11., 12.,
                                        13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,
                                        19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,
                                        25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,
                                        31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.});

    nd4j::ops::upsampling2d op;
    auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// upsampling2d with isNCHW = 1: the input is laid out as {bS, iC, iH, iW} and filled with
// 1, 2, 3, ...; the expected output has shape {bS, iC, iH*factorH, iW*factorW} in which
// every input element is repeated factorW times along W and every row factorH times along H.
TEST_F(ConvolutionTests1, upsampling2d_test2) {

    const int bS=3,  iH=2,iW=2,  iC=3;
    const int factorH=2, factorW=3;
    const int isNCHW = 1;                    // data format, default is NCHW

    auto input  = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1.,  1.,  1.,  2.,  2.,  2., 1.,  1.,  1.,  2.,  2.,  2., 3.,  3.,  3.,  4.,  4.,  4., 3.,  3.,  3.,  4.,  4.,  4.,
                                 5.,  5.,  5.,  6.,  6.,  6., 5.,  5.,  5.,  6.,  6.,  6., 7.,  7.,  7.,  8.,  8.,  8., 7.,  7.,  7.,  8.,  8.,  8., 9.,  9.,  9., 10., 10., 10., 9.,  9.,  9., 10., 10., 10.,11., 11., 11., 12., 12., 12.,11., 11., 11., 12., 12., 12.,
                                13., 13., 13., 14., 14., 14.,13., 13., 13., 14., 14., 14.,15., 15., 15., 16., 16., 16.,15., 15., 15., 16., 16., 16.,17., 17., 17., 18., 18., 18.,17., 17., 17., 18., 18., 18.,19., 19., 19., 20., 20., 20.,19., 19., 19., 20., 20., 20.,
                                21., 21., 21., 22., 22., 22.,21., 21., 21., 22., 22., 22.,23., 23., 23., 24., 24., 24.,23., 23., 23., 24., 24., 24.,25., 25., 25., 26., 26., 26.,25., 25., 25., 26., 26., 26.,27., 27., 27., 28., 28., 28.,27., 27., 27., 28., 28., 28.,
                                29., 29., 29., 30., 30., 30.,29., 29., 29., 30., 30., 30.,31., 31., 31., 32., 32., 32.,31., 31., 31., 32., 32., 32.,
                                33., 33., 33., 34., 34., 34.,33., 33., 33., 34., 34., 34.,35., 35., 35., 36., 36., 36.,35., 35., 35., 36., 36., 36.});

    nd4j::ops::upsampling2d op;
    auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});

    // Verify the op succeeded before touching its outputs.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Forward pass of upsampling3d on a channels-last input filled with 1..N.
// The expected output has shape {bS, iD*factorD, iH*factorH, iW*factorW, iC}.
TEST_F(ConvolutionTests1, upsampling3d_test1) {

    const int bS=3,  iD=2,iH=2,iW=2,  iC=3;
    const int factorD=2,factorH=3,factorW=2;
    const int isNCDHW = 0;                    // layout flag: 0 is used here together with the channels-last {bS, iD, iH, iW, iC} input below

    auto input  = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12.,
             7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 1.,  2.,  3., 1.,  2.,  3., 4.,  5.,  6., 4.,  5.,  6., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12.,
             7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12., 7.,  8.,  9., 7.,  8.,  9.,10., 11., 12.,10., 11., 12.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,
            19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,
            13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,
            25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,
            25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,
            31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,
            43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,
            43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,
            49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,
            49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,
            61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,
            67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,
            67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.});

    nd4j::ops::upsampling3d op;
    auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});

    // Assert the op status first, so a failed execution is reported as a status
    // mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Forward pass of upsampling3d on a channels-first input filled with 1..N.
// The expected output has shape {bS, iC, iD*factorD, iH*factorH, iW*factorW}.
TEST_F(ConvolutionTests1, upsampling3d_test2) {

    const int bS=3,  iD=2,iH=2,iW=2,  iC=3;
    const int factorD=2,factorH=3,factorW=2;
    const int isNCDHW = 1;                    // layout flag: 1 is used here together with the channels-first {bS, iC, iD, iH, iW} input below

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    input.linspace(1);

    auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.,  1.,  2.,  2., 1.,  1.,  2.,  2., 1.,  1.,  2.,  2., 3.,  3.,  4.,  4., 3.,  3.,  4.,  4., 3.,  3.,  4.,  4., 1.,  1.,  2.,  2., 1.,  1.,  2.,  2., 1.,  1.,  2.,  2., 3.,  3.,  4.,  4., 3.,  3.,  4.,  4., 3.,  3.,  4.,  4., 5.,  5.,  6.,  6., 5.,  5.,  6.,  6., 5.,  5.,  6.,  6., 7.,  7.,  8.,  8., 7.,  7.,  8.,  8., 7.,  7.,  8.,  8.,
             5.,  5.,  6.,  6., 5.,  5.,  6.,  6., 5.,  5.,  6.,  6., 7.,  7.,  8.,  8., 7.,  7.,  8.,  8., 7.,  7.,  8.,  8., 9.,  9., 10., 10., 9.,  9., 10., 10., 9.,  9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., 9.,  9., 10., 10., 9.,  9., 10., 10., 9.,  9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12.,
            13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,
            17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,
            25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,
            29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,
            37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,
            41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,
            49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,
            53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,
            61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,
            65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.});

    nd4j::ops::upsampling3d op;
    auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});

    // Assert the op status first, so a failed execution is reported as a status
    // mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto* output = results->at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));

    delete results;
}


//////////////////////////////////////////////////////////////////////
// Backward pass of upsampling3d with an all-ones upstream gradient.
// With factors 2x2x2 the expected input gradient is 8 in every cell,
// i.e. factorD*factorH*factorW upstream ones per input element.
TEST_F(ConvolutionTests1, upsampling3d_bp_test1) {

    const int bS=1,  iD=2,iH=2,iW=2,  iC=1;
    const int factorD=2, factorH=2, factorW=2;
    const int isNCDHW = 1;                    // layout flag: 1 is used here together with the channels-first {bS, iC, iD, iH, iW} arrays below

    // Only the shape of `input` is set up here; its values are never assigned.
    auto input  = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    auto gradO  = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW});
    gradO = 1.;

    auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
    expGradI = 8.;

    nd4j::ops::upsampling3d_bp op;
    auto results = op.execute({&input, &gradO}, {}, {isNCDHW});

    // Assert the op status first, so a failed execution is reported as a status
    // mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto* gradI = results->at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// conv2d_input_bp: the first operand carries the input shape {2, 1, 4, 4} as data,
// followed by the (permuted) weights and the upstream gradient. The result is
// compared against a precomputed {2, 1, 4, 4} gradient.
TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) {

    auto inputShape = NDArrayFactory::create<TypeParam>('c', {4}, {2, 1, 4, 4});
    auto weights = NDArrayFactory::create<TypeParam>('c', {2, 1, 3, 3});
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 2, 4, 4});
    // Used only as a shape template for the expected result.
    auto shapeArr = NDArrayFactory::create<TypeParam>('c', {2, 1, 4, 4});


    TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0};
    NDArray expEps(_expEpsB, shapeArr.getShapeInfo());

    weights.linspace(1);
    epsilonNext.linspace(1);
    // Weights are filled in {2, 1, 3, 3} order and then permuted in place before being passed to the op.
    weights.permutei({2,3,1,0});

    nd4j::ops::conv2d_input_bp op;

    // NOTE(review): int args presumably follow the kH,kW, sH,sW, pH,pW, dH,dW, paddingMode
    // ordering used by the deconv2d tests in this file - confirm against the op definition.
    auto results = op.execute({&inputShape, &weights, &epsilonNext}, {},  {3, 3, 1, 1, 0, 0, 1, 1, 1});

    // Assert the op status before inspecting outputs, as the other tests in this file do.
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(results->size() == 1);

    auto epsilon = results->at(0);

    ASSERT_TRUE(shapeArr.isSameShape(epsilon));
    ASSERT_TRUE(expEps.equalsTo(epsilon));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// Backward pass of upsampling3d with a non-uniform upstream gradient:
// gradO has shape {bS, iC, iD*factorD, iH*factorH, iW*factorW} and the expected
// input gradient has shape {bS, iC, iD, iH, iW}; both are precomputed FLOAT32 arrays.
TEST_F(ConvolutionTests1, upsampling3d_bp_test3) {

    const int bS=1,  iD=3,iH=3,iW=3,  iC=2;
    const int factorD=2, factorH=2, factorW=2;
    const int isNCDHW = 1;                    // layout flag: 1 is used here together with the channels-first {bS, iC, iD, iH, iW} arrays below

    // Only the shape and data type of `input` are set up here; its values are never assigned.
    NDArray input('c', {bS, iC, iD, iH, iW}, nd4j::DataType::FLOAT32);
    NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338,
        0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668,
        0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439,
        0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839,
        0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561,
        0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631,
        0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227,
        0.9171938, 0.84989065, 0.5894228, 0.85227835, 0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047,
        0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033,
        0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843,
        0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876,
        0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908,
        0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415,
        0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304,
        0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846,
        0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239,
        0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083,
        0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075,
        0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565,
        0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615,
        0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824,
        0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565,
        0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802,
        0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821,
        0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676,
        0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527,
        0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158,
        0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731,
        0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452,
        0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176,
        0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142,
        0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622,
        0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457,
        0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804,
        0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821,
        0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333,
        0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, nd4j::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278,
        3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016,
        4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917,
        4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856,
        4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, nd4j::DataType::FLOAT32);

    nd4j::ops::upsampling3d_bp op;
    auto results = op.execute({&input, &gradO}, {}, {isNCDHW});

    // Assert the op status first, so a failed execution is reported as a status
    // mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto* gradI = results->at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    delete results;
}


//////////////////////////////////////////////////////////////////////
// deconv2d with VALID padding on NHWC data: a constant 0.5 input of shape
// {bS, oH, oW, oC} and linearly spaced weights {kH, kW, iC, oC} are expected
// to yield the precomputed {bS, iH, iW, iC} array below.
TEST_F(ConvolutionTests1, deconv2d_test1) {

    const int bS=2, iH=4,iW=4,  iC=5,oC=10;
    const int kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    const int oH=3,oW=3;
    const int paddingMode = 0;             // 1-SAME, 0-VALID;
    const int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // operands
    auto input    = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
    input = 0.5;

    auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
    weights.linspace(0.1, 0.1);

    // reference values
    auto expected = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, {  2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 27.75,  32.75,  37.75,  42.75,  47.75,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  52.75,  57.75,  62.75,  67.75,  72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75,  82.75,  87.75,  92.75,  97.75,
                                                   2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 27.75,  32.75,  37.75,  42.75,  47.75,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  52.75,  57.75,  62.75,  67.75,  72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75,  82.75,  87.75,  92.75,  97.75});

    nd4j::ops::deconv2d deconvOp;
    auto resultSet = deconvOp.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // status is asserted before the output is fetched
    ASSERT_EQ(Status::OK(), resultSet->status());

    auto actual = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete resultSet;
}

//////////////////////////////////////////////////////////////////////
// deconv2d with SAME padding on NHWC data: a constant 0.5 input of shape
// {bS, oH, oW, oC} (oH == iH, oW == iW) and linearly spaced weights
// {kH, kW, iC, oC} are expected to yield the precomputed {bS, iH, iW, iC} array below.
TEST_F(ConvolutionTests1, deconv2d_test2) {

    int bS=2, iH=4,iW=4,  iC=5,oC=10,  kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=4,oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<double>('c', {bS, oH, oW, oC});
    auto weights  = NDArrayFactory::create<double>('c', {kH, kW, iC, oC});
    auto exp = NDArrayFactory::create<double>('c', {bS, iH, iW, iC}, {2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,
                                                 2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,
                                                55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  });
    input = 0.5;
    weights.linspace(0.1, 0.1);

    nd4j::ops::deconv2d op;
    auto results = op.execute({&input, &weights}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Assert the op status first (as deconv2d_test1 does), so a failed execution is
    // reported as a status mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
// deconv2d_tf with VALID padding on NHWC data. Unlike deconv2d, the op is given three
// operands in the order {output shape as data, weights, input}; the expected values are
// the same precomputed {bS, iH, iW, iC} array used by deconv2d_test1.
TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {

    int bS=2, iH=4,iW=4,  iC=5,oC=10,  kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    // Requested output shape {bS, iH, iW, iC}, stored as array values of the tested type.
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
    auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {  2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 27.75,  32.75,  37.75,  42.75,  47.75,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  52.75,  57.75,  62.75,  67.75,  72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75,  82.75,  87.75,  92.75,  97.75,
                                                   2.75,   7.75,  12.75,  17.75,  22.75, 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 30.5 ,  40.5 ,  50.5 ,  60.5 ,  70.5 , 27.75,  32.75,  37.75,  42.75,  47.75,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  55.5 ,  65.5 ,  75.5 ,  85.5 ,  95.5 ,161.  , 181.  , 201.  , 221.  , 241.  ,161.  , 181.  , 201.  , 221.  , 241.  ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 ,
                                                  52.75,  57.75,  62.75,  67.75,  72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75,  82.75,  87.75,  92.75,  97.75});
    input = 0.5;
    weights.linspace(0.1, 0.1);

    nd4j::ops::deconv2d_tf op;
    auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Assert the op status first, so a failed execution is reported as a status
    // mismatch instead of by accessing an output that may not exist.
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete results;
}


#endif //LIBND4J_CONVOLUTIONTESTS1_H