// File: cavis/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
// (370 lines, 97 KiB, C++; copied from a web code viewer's "Raw / Normal View / History" page,
//  revision dated 2019-06-06 14:21:15 +02:00)
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com), created 02.04.2019
//
#ifndef LIBND4J_CONVOLUTIONTESTS2_H
#define LIBND4J_CONVOLUTIONTESTS2_H
#include "testlayers.h"
#include <NDArray.h>
#include <Context.h>
#include <Node.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/col2im.h>
#include <PointersManager.h>
using namespace nd4j;
using namespace nd4j::graph;
// Non-typed GoogleTest fixture shared by the TEST_F cases in this file.
// It carries no state and needs no set-up or tear-down.
class ConvolutionTests2 : public testing::Test {
};
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, im2col_1) {
    // batch, image size, channels, kernel, strides, padding, dilation
    int bS=2, iH=4,iW=3, iC=4, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int paddingMode = 0;   // 1-SAME, 0-VALID

    // VALID output extent, computed from the dilated kernel size
    const int dilatedKH = kH + (kH-1)*(dH-1);
    const int dilatedKW = kW + (kW-1)*(dW-1);
    int oH = (iH - dilatedKH + 2*pH)/sH + 1;
    int oW = (iW - dilatedKW + 2*pW)/sW + 1;

    // image is filled with 1, 2, 3, ... in c-order
    NDArray image('c', {bS, iC, iH, iW}, nd4j::DataType::DOUBLE);
    image.linspace(1, 1);

    // expected column buffer, laid out as [bS, iC, kH, kW, oH, oW]
    NDArray expected('c', {bS, iC, kH, kW, oH, oW}, {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14,
    15, 17, 18, 16, 17, 19, 20, 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, 29, 30,
    28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43,
    44, 41, 42, 44, 45, 43, 44, 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, 53, 54,
    56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67,
    68, 70, 71, 68, 69, 71, 72, 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, 82, 83,
    80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96});

    nd4j::ops::im2col op;
    auto results = op.execute({&image}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode});
    ASSERT_EQ(Status::OK(), results->status());

    auto column = results->at(0);
    ASSERT_TRUE(expected.isSameShape(column));
    ASSERT_TRUE(expected.equalsTo(column));

    delete results;
}
// Typed GoogleTest fixture: each TYPED_TEST below is instantiated once for every
// element type listed in TestingTypes (available inside a test as TypeParam).
template <typename T>
class TypedConvolutionTests2 : public testing::Test {
};

using TestingTypes = ::testing::Types<double, float>;
TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes);
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) {
    int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // the requested result shape {bS, iH, iW, iC} is passed to the op as a data array
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    weights.linspace(0.1, 0.1);
    input = 0.5;

    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
    2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,
    55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. });

    // inputs are ordered: output shape, weights, input
    nd4j::ops::deconv2d_tf op;
    auto results = op.execute({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto output = results->at(0);

    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) {
    // operands in deconv2d_tf order: output shape, weights, input (weights and input stay at their initial values)
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {12.f, 5.f, 5.f, 32.f});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {2, 2, 32, 16});
    auto input    = NDArrayFactory::create<TypeParam>('c', {12, 4, 4, 16});
    auto expected = NDArrayFactory::create<TypeParam>('c', {12, 5, 5, 32});

    // kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat
    int kH = 2, kW = 2, sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1;
    int paddingMode = 0;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    nd4j::ops::deconv2d_tf op;
    auto result = op.execute({&outShape, &weights, &input}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), result->status());
    ASSERT_EQ(expected, *result->at(0));

    delete result;
}
//////////////////////////////////////////////////////////////////////
// deconv2d_tf with a {7,7,16,5} weights array, an input of shape {3,4,4,5} and a
// requested output shape {3,8,8,16}; the result is compared with a literal expected array.
// IArgs follow the layout used elsewhere in this file:
// kH,kW = 7,7, sH,sW = 2,2, pH,pW = 0,0, dH,dW = 1,1, paddingMode = 1 (SAME), dataFormat = 1 (NHWC).
// NOTE(review): in this copy the literal initializer lists are corrupted.
//   - Numbers are broken across lines (e.g. "-1.477627" / "52", "0.69269" / "875").
//   - The input1 and exp lists end without a closing "});".
// This test therefore does not compile as shown; restore the data from the original source.
TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) {
auto input0 = NDArrayFactory::create<TypeParam>('c', {4}, {3, 8, 8, 16});
auto input1 = NDArrayFactory::create<TypeParam>('c', {7, 7, 16, 5}, {1.05293429,-0.89349967,0.31027254,1.22991478,-0.62926656,0.56918693,-1.60992694,1.10167944,-0.80843484,0.07521993,-1.15994942,0.76016301,-0.40056285,-1.16872537,-0.91384381,-0.36700436,1.82389200,-1.18200207,0.51612782,-0.92479187,-0.09307563,-0.55122334,1.23532486,-1.11124146,-0.05812126,0.68159896,0.69125599,-0.77127314,-0.10874277,0.86469102,-1.31614351,0.33354419,-1.71750402,0.17197680,-1.03965557,1.10570908,-1.19115615,1.05115080,0.18277600,1.08820546,-0.72191417,-0.10999311,1.56521320,-0.35433730,-1.11799145,0.34499285,0.64998639,-1.64371550,0.92592359,-0.47659501,0.49101439,-0.15613313,1.47486567,0.43576995,2.19538260,-0.83567709,-1.21846950,0.80400819,1.14637423,-1.01503456,-0.61992753,-0.47378838,0.86503726,0.27147385,0.37073180,-0.19951358,0.79167330,-0.33982825,0.18631981,-1.54715073,0.39967480,0.95067030,1.12508667,-0.86676019,-1.10341156,2.33141375,1.10972047,0.71407092,1.70640314,1.80666339,0.59465605,-0.39653218,-2.61163163,-1.15013492,-1.19908321,0.41783467,-0.22730024,0.31425011,-0.58562893,-0.10131568,-0.85047537,-2.59974790,1.22072542,-2.08812046,-0.19363593,-1.27664304,-0.02703438,1.08477545,-0.65506506,0.46040919,-0.13715318,-0.74945593,-0.69006950,-1.29617655,-0.15865716,1.38956285,0.90216327,-1.31185400,-0.15067385,-0.63093358,-0.05895613,0.26545224,0.29332840,0.42852548,0.72409540,0.12879130,1.43038857,0.68647617,2.19654775,0.51878077,-0.03769343,0.52877223,-0.21733910,1.13710785,-0.59003806,1.54624867,-0.64997369,-1.03239334,0.19708300,0.68658423,0.71048903,-1.55250466,-1.38636279,0.32385820,0.81226677,0.19209047,-0.23002781,-0.63631231,1.02101684,0.65428704,-0.17206922,1.09488952,1.03022420,-0.95567745,-0.07595373,-1.48606372,2.57174873,-1.75366247,1.12913883,0.97053039,-0.28552356,0.56511772,-0.79568213,0.07561764,-1.02085686,1.05770981,-1.25715709,0.42046708,-2.57390857,0.96947151,1.05215812,0.65624017,-1.29019403,0.64157075,-0.40509227,-0.65354455,0.42348680,-1.34107757
,0.05931387,-0.54337227,0.95460182,1.59319806,-0.44433126,-0.33717924,0.79566282,0.50112695,-0.22244534,1.76904583,-0.89817202,1.82985342,0.17671813,0.80720717,1.32469308,0.39417782,-0.23720963,0.96796370,-1.02348757,-0.86615551,-1.58120525,-0.37634999,0.00905940,0.01880967,1.75771821,-0.64372772,0.36687651,0.15854552,-0.67599791,0.53726906,-1.20158446,-1.78549063,0.96476388,-0.66158366,-0.41681561,-0.97541636,2.35928202,0.32130197,1.06886065,1.38736427,-0.73718959,0.11215294,2.12865782,-0.37927702,0.55621815,-1.10108411,-0.02032263,0.29595461,1.58737493,1.24001300,-0.66748160,0.80729002,-0.10575818,-1.03175950,1.80755460,0.10825710,2.20666361,1.33633149,1.39290452,0.45211342,-0.07837920,2.08304930,-0.28387162,-0.70775616,0.43626297,0.53556961,0.06201901,-0.59255266,-0.11854446,2.10024118,0.37638292,-0.56178707,-0.25220188,-1.23731256,-1.30002999,0.34283713,0.30502397,-1.09233856,1.12430644,0.52273953,-0.68507338,-0.69913578,0.88440478,-0.76959240,1.07093310,-0.34802195,0.35683727,-0.76079178,-1.92807376,0.84499562,1.39131641,0.44825050,0.34567752,0.44607711,-1.00986362,-0.50038189,-0.09060892,-2.55645394,0.56416476,-0.83058155,-0.65931624,-0.73649710,0.59814465,-0.86736494,-0.32200798,-1.28087902,-0.76818323,0.86848933,-0.98678392,-1.30813944,-0.20255326,0.26557815,-0.31090519,-1.46331608,-0.62782109,0.59034890,1.63147473,-0.17727259,-0.37636510,1.27368402,0.19096918,-0.29936951,-1.99038267,0.54831523,0.48849005,-2.55680346,-0.63126534,1.21715927,1.22841084,-0.67416084,0.02927168,-0.36693662,0.63204330,0.13721083,0.28742912,0.19470036,0.74873924,-1.47602463,0.86264688,-0.23730527,-0.99978864,-1.17048764,-0.34996086,1.43019187,0.26224539,0.60689932,-0.75002515,-0.79823422,-1.37300086,-0.19951135,-0.12150808,-0.75272322,0.23755015,0.31270382,1.66539109,-1.04104745,0.79540199,-0.54042423,-0.54150617,0.43871084,0.24163951,-0.24517761,-0.66178995,-1.13064528,-0.84426326,0.56437236,0.09088907,-0.82823074,0.81753862,-1.74096012,-1.80599844,-0.60943592,1.36094582,-1.477627
52,0.15931177,1.05569172,0.36751524,0.06497604,0.13536447,-1.57156146,0.22783801,-0.9691010
auto input2 = NDArrayFactory::create<TypeParam>('c', {3, 4, 4, 5}, {0.98114507,0.96400015,0.58669623,0.60073098,0.75425418,0.44258752,0.76373084,0.96593234,0.34067846,0.57962620,0.77517051,0.97472977,0.79237527,0.68690428,0.21719366,0.79959206,0.84814187,0.22496814,0.08646965,0.31110474,0.79813162,0.19661444,0.57760099,0.72138960,0.15244268,0.87687051,0.11130344,0.01087698,0.34817841,0.54992017,0.23443850,0.31725614,0.59755220,0.20364695,0.00531392,0.23403114,0.07442912,0.83707647,0.89291743,0.09044587,0.69041462,0.29904183,0.61904680,0.85306847,0.34467042,0.95839152,0.54517124,0.29640937,0.94855959,0.95970016,0.94045145,0.95510301,0.34666505,0.34717010,0.69245678,0.71669175,0.59043738,0.64924132,0.06033522,0.60185199,0.04690073,0.59241154,0.40229547,0.23002481,0.45161195,0.73743778,0.93209113,0.37294358,0.50177744,0.15072501,0.26146917,0.05252146,0.04758931,0.76448288,0.85149045,0.08840467,0.07692576,0.33180160,0.27241259,0.74834620,0.56453640,0.23057286,0.68429752,0.11961551,0.39045977,0.44356094,0.77018807,0.07984410,0.47926806,0.26165759,0.18606064,0.89972877,0.17962874,0.47273120,0.64641705,0.61890443,0.58730015,0.25937832,0.35231561,0.10243882,0.17459193,0.95906995,0.09227025,0.30003223,0.41601210,0.38269713,0.84799751,0.59295173,0.76277990,0.68910424,0.37672606,0.40675461,0.94346058,0.91438505,0.84728183,0.64367667,0.74899979,0.60570691,0.16417363,0.68852426,0.85486889,0.22585792,0.86953176,0.07465519,0.93096301,0.38008822,0.38752587,0.44004038,0.13170612,0.94541045,0.89349973,0.69245307,0.94978877,0.98776658,0.79445884,0.30607409,0.58264961,0.37980538,0.41810784,0.48903038,0.51615888,0.57682794,0.82481897,0.78341080,0.48446465,0.17447931,0.71125424,0.30263851,0.70675352,0.03215584,0.92381065,0.22343694,0.08851149,0.91402490,0.70074717,0.30912192,0.37723206,0.97579397,0.23554587,0.95939133,0.41565709,0.01741416,0.58362787,0.22106662,0.89065537,0.31900249,0.41280911,0.67947610,0.04545590,0.15352812,0.85412524,0.84933222,0.80000225,0.93147073,0.70094105,0.69269
875,0.95282194,0.65913582,0.79186874,0.59855248,0.39707430,0.95126239,0.15618217,0.33446689,0.98123758,0.84770758,0.98081012,0.54427413,0.18728519,0.89792955,0.53360126,0.72812986,0.13307744,0.51217443,0.66708084,0.29416915,0.31298995,0.39155037,0.29288291,0.87063305,0.61759154,0.73723332,0.37167635,0.82122716,0.22937430,0.76570536,0.47911792,0.02826214,0.94277323,0.59945469,0.19042060,0.68173155,0.82771295,0.95649538,0.40833101,0.90838542,0.55245881,0.49011012,0.36773444,0.34513527,0.42050683,0.16113964,0.30969388,0.27174174,0.12117655,0.35270175,0.81967867,0.63723136,0.84309389,0.71822576,0.84883484,0.32306117,0.08176457,0.56175486,0.34892198,0.09306929,0.85437582,0.13925577,0.48629188,0.29923539});
auto exp = NDArrayFactory::create<TypeParam>('c', {3, 8, 8, 16}, {5.98743296,-2.83037376,-0.87943113,1.41339970,1.32433391,-1.20299149,-0.02893090,2.05326009,1.19417048,5.58212376,3.28139353,1.19237995,-1.09431255,-2.55264497,3.11014652,6.81296825,-2.09029293,-4.32068443,-0.52808392,-1.97968531,-0.18673831,0.84605980,4.55825520,2.71503139,0.15210046,0.85310984,-3.82062817,2.76470995,3.69004202,-1.45017099,-2.59361267,-1.35094655,7.24145126,-5.25432396,0.19920218,-4.30596399,1.35318923,-3.88142037,3.67493343,2.25931478,2.87630725,1.66349852,6.21347952,0.94105923,-1.61742055,-2.35699606,0.12850338,1.79141688,-2.09535933,-6.35418081,-0.06303531,-4.38615131,0.48237842,0.26528549,3.38231516,3.76315165,-0.40254810,-0.23716694,-6.13381910,-0.41950428,-0.89680839,-1.46491277,-1.98541689,-0.99357355,5.58237648,-2.38937521,-0.00872564,-2.37138414,4.91117287,-4.51916361,0.97943687,2.91052818,-2.50362611,1.70252812,5.04137802,3.57108784,-1.87532270,-3.66677809,-2.38861251,5.55765152,-7.27571774,-1.68887305,-0.72266489,-4.42809057,-0.92118186,1.02381468,4.44284725,5.17150497,-0.42438728,2.02693963,-1.36484981,-1.47912180,0.26649538,-0.02091765,-2.86906910,-3.03046989,1.35122132,-3.21707630,2.21112418,0.24121630,3.96940088,-7.66105747,2.76352382,-0.99061489,-2.16720009,-1.63170409,1.12701774,-1.02415371,-0.90435314,-1.51372027,-0.76884907,0.39066136,-0.89562428,-2.03204703,1.28074932,-2.14551091,-2.36843777,0.46580017,0.75451565,-0.00336730,-1.06597757,3.27195978,-0.41307712,-0.10376054,-1.34102952,-2.22901654,2.31929803,1.40851438,-2.23774385,0.20417206,-1.12153268,-0.13188094,-3.96649432,2.10269976,0.49845099,6.18937683,-0.51783508,-0.48048639,-1.92970264,3.16670656,1.13355756,-0.07890664,1.31536257,-0.43924797,-0.04562932,-0.87974954,0.75411212,-2.39745235,-3.97132111,0.37202546,-2.40399146,-1.50796390,-3.08302689,0.23075986,-0.94316757,1.34948587,0.58591264,2.18529797,7.97652435,2.32798409,-4.09404373,0.89634895,0.77697754,-0.65091681,-7.05506849,5.86194515,2.51394033,4.69959
354,0.20835471,3.18049693,-1.29682434,3.70832396,-0.48123091,-1.67904007,-1.35418940,1.58435583,-1.13851106,-1.19225955,0.59713769,-5.80462933,-7.45143986,-1.08658695,1.03244078,-1.75307107,-7.07100582,3.85825157,1.62127817,2.32572675,0.56171900,-0.80591971,3.98835945,0.15742642,-2.97832179,0.13821673,-0.72556758,-0.84936106,-7.28444147,3.94134307,0.80779338,7.47784615,8.23335075,4.80595016,-4.89574575,4.03362942,-6.67522192,-4.55204487,2.12511182,-2.70781207,-1.57226098,-3.08408356,-0.30812448,-5.32870674,-5.13238287,0.49605465,-0.55042171,0.46324944,-3.83545256,-0.12562510,-0.20978995,-0.13068712,-1.92144060,-1.68787408,5.45581436,-0.79583496,-2.38866687,-3.90546346,-0.47028148,-0.14319679,-3.37016582,2.00905991,-1.21345615,1.81376505,7.73004007,0.74310112,-4.64536428,3.78111577,-9.05182457,-0.10674095,1.53476238,0.63345337,-0.40907967,-1.44729769,-1.87145400,-2.46623540,1.07472968,0.77390999,-3.93438888,4.49174690,-0.96686655,1.92278123,0.30049133,-0.02388665,-1.99777114,-3.23885751,5.87784004,2.13776040,3.56758308,-3.37774134,-3.67526293,1.63700044,-1.69959962,-0.99112594,6.03103638,1.67399430,-1.28699589,7.16759014,12.63490295,3.62937450,-4.75982571,2.17861104,-2.03065681,4.30207729,-0.46797156,-2.96022511,-6.02702332,3.09229851,-1.39771092,-0.03471333,3.22175527,5.63565636,1.78195477,-0.63545251,-3.99497652,1.46043062,4.60050488,-2.96651959,-2.03159475,-1.52386189,-0.15129802,-3.90390921,-0.63852370,0.79210538,2.35288715,-5.55609035,5.36427498,-0.60248077,-0.26181316,5.04884720,8.53192806,5.05080223,-6.56371737,1.52260923,-7.13623667,6.49414349,2.33445597,-4.11490965,-6.44347477,-0.47079402,-0.63467920,2.60399365,1.05958164,3.66901422,-1.05657935,1.88611507,-6.37475634,2.01480770,3.36020517,-5.11001921,-0.46132171,2.16525555,4.21938848,-2.08346295,2.86168146,1.26987600,6.76066971,-7.84916353,4.11700916,0.47985530,-4.60113716,7.42062473,6.37472820,4.37820530,-7.12197018,0.01357239,-7.90392113,8.32131577,-0.87593079,-0.16994858,-5.86345863,-0.20697471,-1.3784520
6,1.63819647,1.59720242,-0.74357712,-1.88725603,-1.98357940,-8.57950306,-4.10104513,3.57231
nd4j::ops::deconv2d_tf op;
auto result = op.execute({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests2, Test_Conv2D_TF_1) {
    // Smoke test: conv2d must accept these shapes and report success; no values are checked.
    auto input   = NDArrayFactory::create<TypeParam>('c', {54, 1, 12, 12});
    auto weights = NDArrayFactory::create<TypeParam>('c', {1, 2, 12, 2});

    // kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat
    // NOTE(review): kH,kW = -1 presumably means "take the kernel size from the weights shape" — confirm.
    int kH = -1, kW = -1, sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1;
    int paddingMode = 1;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    nd4j::ops::conv2d op;
    auto result = op.execute({&input, &weights}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), result->status());

    delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) {
    // Shape-only check: the output shape is compared, its values are not.
    auto x        = NDArrayFactory::create<double>('c', {4, 128, 128, 4});
    auto w        = NDArrayFactory::create<double>('c', {4, 5, 4});
    auto expected = NDArrayFactory::create<double>('c', {4, 64, 43, 4});

    // IArgs: a leading flag (1 here, 0 in Test_Dilation2D_Again_2), then groups {1,5,7,1} and {1,2,3,1}.
    // NOTE(review): presumably SAME-mode flag, rates, strides — consistent with 64 = ceil(128/2), 43 = ceil(128/3).
    nd4j::ops::dilation2d op;
    auto result = op.execute({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1});
    ASSERT_EQ(Status::OK(), result->status());

    ASSERT_TRUE(expected.isSameShape(result->at(0)));

    delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
    // Smoke test: dilation2d must succeed for these shapes; nothing else is checked.
    auto x = NDArrayFactory::create<double>('c', {4, 26, 19, 4});
    auto w = NDArrayFactory::create<double>('c', {11, 7, 4});

    // IArgs: leading flag 0, then groups {1,2,3,1} and {1,3,2,1} (same layout as Test_Dilation2D_Again_1)
    nd4j::ops::dilation2d op;
    auto result = op.execute({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1});

    ASSERT_EQ(Status::OK(), result->status());

    delete result;
}
//////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_test2) {
    // Repeatability test: run sconv2d_bp once to capture reference gradients, then re-run it
    // into the same output arrays several times and require identical results every time.
    int bS=3, iH=16,iW=16, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2;
    int oH=16,oW=16;
    int oC=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    // every array uses the data type that matches the typed-test parameter
    const auto dtype = typeid(TypeParam) == typeid(float) ? nd4j::DataType::FLOAT32 : nd4j::DataType::DOUBLE;

    NDArray input('c', {bS, iC, iH, iW}, dtype);
    NDArray gradO('c', {bS, oC, oH, oW}, dtype);
    NDArray weightsDepth('c', {kH, kW, iC, mC}, dtype);
    NDArray weightsPoint('f', {1, 1, iC*mC, oC}, dtype);
    NDArray bias('c', {1,oC}, {0.5, 0.5}, dtype);

    // gradient outputs are constructed from (and shaped like) the corresponding inputs
    NDArray gradI(&input);
    NDArray gradWD(&weightsDepth);
    NDArray gradWP(&weightsPoint);
    NDArray gradB(&bias);

    input = 2.;
    weightsDepth.linspace(0.1, 0.1);
    weightsPoint.linspace(0.15, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::sconv2d_bp op;
    Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias},
                                   {&gradI, &gradWD, &gradWP, &gradB},
                                   {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);

    // reference results from the first run (deep copies)
    NDArray expGradI  = gradI;
    NDArray expGradWD = gradWD;
    NDArray expGradWP = gradWP;
    NDArray expGradB  = gradB;

    for (int i = 0; i < 10; ++i) {
        const Nd4jStatus repeatStatus = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias},
                                                   {&gradI, &gradWD, &gradWP, &gradB},
                                                   {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
        ASSERT_EQ(Status::OK(), repeatStatus);
        ASSERT_TRUE(expGradI.equalsTo(gradI));
        ASSERT_TRUE(expGradWD.equalsTo(gradWD));
        ASSERT_TRUE(expGradWP.equalsTo(gradWP));
        // compare against the freshly computed bias gradient, not against the reference itself
        ASSERT_TRUE(expGradB.equalsTo(gradB));
    }
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, im2col_bp_1) {
    // Smoke test: im2col_bp must accept these shapes and report success; output values are not checked.
    int bS=3, iH=12,iW=12, iC=6, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=12,oW=12;
    int paddingMode = 1;             // 1-SAME, 0-VALID (same argument position as in im2col_1)

    // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
    NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::DOUBLE);
    NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, nd4j::DataType::DOUBLE);
    NDArray gradI('c', {bS, iC, iH, iW}, nd4j::DataType::DOUBLE);   // output

    nd4j::ops::im2col_bp op;
    Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode}, {});
    ASSERT_EQ(Status::OK(), status);
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, depthwise_conv2d_test5) {
    int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
    int oC = iC*mC;
    int oH = 56, oW = 56;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    // Sentinel pre-filled into the output: any element still equal to it was never written by the op.
    const float unique = -1000000;

    NDArray input  ('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, nd4j::DataType::FLOAT32);
    NDArray output ('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);

    input.linspace(0.1, 0.0001);
    weights = 0.5;
    output  = unique;

    nd4j::ops::depthwise_conv2d op;
    Nd4jStatus status = op.execute({&input, &weights}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);

    // only the trailing third of the output buffer is inspected
    const Nd4jLong total = output.lengthOf();
    for (Nd4jLong i = static_cast<Nd4jLong>(total / 1.5); i < total; ++i)
        ASSERT_NE(unique, output.e<float>(i));
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, conv2d_bp_4) {
    // Smoke test: conv2d_bp on a width-1 NCHW input must report success; gradients are not checked.
    int bS=1, iH=7,iW=1, iC=2,oC=3, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=7,oW=1;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    const auto dtype = nd4j::DataType::FLOAT32;

    // forward inputs
    NDArray input  ('c', {bS, iC, iH, iW}, dtype);
    NDArray weights('c', {kH, kW, iC, oC}, dtype);
    NDArray bias   ('c', {oC}, {1,2,3}, dtype);
    NDArray gradO  ('c', {bS, oC, oH, oW}, dtype);

    // gradient outputs
    NDArray gradI('c', {bS, iC, iH, iW}, dtype);
    NDArray gradW('c', {kH, kW, iC, oC}, dtype);
    NDArray gradB('c', {oC}, dtype);

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::conv2d_bp op;
    const auto status = op.execute({&input, &weights, &bias, &gradO},
                                   {&gradI, &gradW, &gradB},
                                   {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);
}
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, sconv2d_bp_2) {
    // Smoke test: sconv2d_bp with an 'f'-ordered depth-wise weight gradient must report success.
    // Mirrors the Java snippet kept in the comment below this test.
    int bS=1, iH=8,iW=8, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=8,oW=8;
    int oC=2;                        // iC*mC if weightsPoint = nullptr
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    // forward inputs
    auto input        = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto gradO        = NDArrayFactory::create<double>('c', {bS, oC, oH, oW});
    auto weightsDepth = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    auto weightsPoint = NDArrayFactory::create<double>('c', {1, 1, iC*mC, oC});
    auto bias         = NDArrayFactory::create<double>('c', {1,oC}, {1,2});

    // gradient outputs
    auto gradI  = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto gradWD = NDArrayFactory::create<double>('f', {kH, kW, iC, mC});
    auto gradWP = NDArrayFactory::create<double>('c', {1, 1, iC*mC, oC});
    auto gradB  = NDArrayFactory::create<double>('c', {1,oC}, {1,2});

    input = 2.;
    weightsDepth.linspace(0.1, 0.1);
    // The point-wise weights get their own fill; re-filling weightsDepth here would overwrite
    // the line above and leave weightsPoint at its initial values.
    weightsPoint.linspace(-0.5, 0.1);
    gradO.linspace(0.01, 0.01);

    nd4j::ops::sconv2d_bp op;
    auto status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, {&gradI, &gradWD, &gradWP, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);
}
// @Test
// public void testSconv2dbp(){
// DynamicCustomOp op = DynamicCustomOp.builder("sconv2d_bp")
// .addInputs(Nd4j.create(DataType.DOUBLE, 1,3,8,8), // input
// Nd4j.create(DataType.DOUBLE, 1, 2, 8, 8), // gradO
// Nd4j.create(DataType.DOUBLE, 1, 1, 3, 3), // weightsDepth
// Nd4j.create(DataType.DOUBLE, 1, 1, 9, 2), // weightsPoint
// Nd4j.create(DataType.DOUBLE, 1, 2))
// .addOutputs(Nd4j.create(DataType.DOUBLE, 1, 3, 8, 8),
// Nd4j.create(DataType.DOUBLE, new long[]{1, 1, 3, 3}, 'f'),
// Nd4j.create(DataType.DOUBLE, 1, 1, 9, 2),
// Nd4j.create(DataType.DOUBLE, 1, 2))
// .addIntegerArguments(1,1, 1,1, 0,0, 1,1, 0)
// .build();
// Nd4j.exec(op);
// }
#endif //LIBND4J_CONVOLUTIONTESTS2_H