2859 lines
343 KiB
C++
2859 lines
343 KiB
C++
/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com), created 02.04.2019
//
|
|
|
|
#ifndef LIBND4J_CONVOLUTIONTESTS2_H
|
|
#define LIBND4J_CONVOLUTIONTESTS2_H
|
|
|
|
#include "testlayers.h"
|
|
#include <array/NDArray.h>
|
|
#include <graph/Context.h>
|
|
#include <graph/Node.h>
|
|
#include <graph/Variable.h>
|
|
#include <graph/VariableSpace.h>
|
|
#include <ops/declarable/CustomOperations.h>
|
|
#include <ops/declarable/helpers/convolutions.h>
|
|
#include <ops/declarable/helpers/col2im.h>
|
|
#include <helpers/PointersManager.h>
|
|
#include <helpers/GradCheck.h>
|
|
|
|
using namespace sd;
|
|
using namespace sd::graph;
|
|
|
|
/**
 * Fixture shared by the non-typed convolution tests in this file.
 *
 * Holds one fixed 2D convolution geometry: a 28x28 input with one channel,
 * a 5x5 kernel dilated by 2, unit strides and zero padding.
 * oH/oW are derived from the other members; the effective (dilated) kernel
 * extent is k + (k-1)*(d-1). Members are initialized in declaration order,
 * so oH/oW must stay declared after everything they depend on.
 */
class ConvolutionTests2 : public testing::Test {
public:

    const int bS = 2;             // batch size
    const int iD = 1;             // input depth (number of picture channels, for example rgb=3)
    const int iH = 28;            // picture height in pixels
    const int iW = 28;            // picture width in pixels
    const int oD = 3;             // output depth (= N for dense layer)
    const int kH = 5;             // kernel height in pixels
    const int kW = 5;             // kernel width in pixels
    const int sH = 1;             // stride step in vertical direction (along height)
    const int sW = 1;             // stride step in horizontal direction (along width)
    const int pH = 0;             // padding height
    const int pW = 0;             // padding width
    const int dH = 2;             // dilation height
    const int dW = 2;             // dilation width
    // with the values above the dilated kernel spans 9 pixels, so oH = oW = (28 - 9)/1 + 1 = 20
    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1;    // output height
    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1;    // output width
};
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, im2col_1) {

    // im2col on a linspace-filled [bS, iC, iH, iW] image with VALID padding;
    // the result must be the [bS, iC, kH, kW, oH, oW] column tensor in `expected`.
    int bS=2, iH=4,iW=3, iC=4, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;     // VALID
    int oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1;     // VALID

    int paddingMode = 0;             // 1-SAME, 0-VALID;

    NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE);
    NDArray expected('c', {bS, iC, kH, kW, oH, oW}, {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14,
                                                     15, 17, 18, 16, 17, 19, 20, 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, 29, 30,
                                                     28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43,
                                                     44, 41, 42, 44, 45, 43, 44, 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, 53, 54,
                                                     56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67,
                                                     68, 70, 71, 68, 69, 71, 72, 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, 82, 83,
                                                     80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96});

    image.linspace(1, 1);

    sd::ops::im2col op;
    auto results = op.evaluate({&image}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode});

    // Verify the op status before reading its output. A failing op is then
    // reported as a status mismatch instead of through an access into a
    // result set that may hold no output.
    ASSERT_EQ(Status::OK(), results.status());
    auto column = results.at(0);

    ASSERT_TRUE(expected.isSameShape(column));
    ASSERT_TRUE(expected.equalsTo(column));
}
|
|
|
|
template <typename T>
|
|
class TypedConvolutionTests2 : public testing::Test {
|
|
public:
|
|
|
|
};
|
|
|
|
typedef ::testing::Types<double, float> TestingTypes;
|
|
TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes);
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) {

    // deconv2d_tf with SAME padding and NHWC layout. The first op input is the
    // requested output shape {bS, iH, iW, iC}. The input is filled with 0.5 and
    // the weights with linspace(0.1, 0.1).
    int bS=2, iH=4,iW=4,  iC=5,oC=10, kH=2,kW=2,  sH=1,sW=1,  pH=0,pW=0,  dH=1,dW=1;
    int       oH=4,oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
    auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {2.75f,  7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
                                                     2.75f,  7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
                                                     55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f});
    input = 0.5;
    weights.linspace(0.1, 0.1);

    sd::ops::deconv2d_tf op;
    auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW,  sH,sW,  pH,pW,  dH,dW, paddingMode, dataFormat});

    // Verify the op status before reading its output. A failing op is then
    // reported as a status mismatch instead of through an access into a
    // result set that may hold no output.
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) {

    // Integer op arguments, in the order deconv2d_tf takes them:
    // kH,kW, sH,sW, pH,pW, dH,dW, paddingMode (1-SAME, 0-VALID), dataFormat (1-NHWC, 0-NCHW)
    const int kH = 2, kW = 2;
    const int sH = 1, sW = 1;
    const int pH = 0, pW = 0;
    const int dH = 1, dW = 1;
    const int paddingMode = 0;
    const int dataFormat  = 1;

    // The first op input carries the requested output shape. The output is
    // compared against a freshly created array of exactly that shape.
    auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {12.f, 5.f, 5.f, 32.f});
    auto weights  = NDArrayFactory::create<TypeParam>('c', {2, 2, 32, 16});
    auto input    = NDArrayFactory::create<TypeParam>('c', {12, 4, 4, 16});
    auto expected = NDArrayFactory::create<TypeParam>('c', {12, 5, 5, 32});

    sd::ops::deconv2d_tf op;
    auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), results.status());

    ASSERT_EQ(expected, *results.at(0));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) {
|
|
auto input0 = NDArrayFactory::create<TypeParam>('c', {4}, {3.f, 8.f, 8.f, 16.f});
|
|
|
|
auto input1 = NDArrayFactory::create<TypeParam>('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f,
|
|
-1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f,
|
|
-1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f,
|
|
2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f,
|
|
1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f,
|
|
-0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f,
|
|
1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f,
|
|
0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f,
|
|
-0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f,
|
|
0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f,
|
|
-1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f,
|
|
-1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f,
|
|
-0.73649710f, 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, -0.00358280f, 0.65680492f, 0.04006556f, 0.57160908f, 0.43821687f, 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, 0.16614936f, 
2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, -1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, 
-0.28529504f, -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, 0.58563638f, 1.28528512f, -2.13718534f, 0.28842899f, 0.12676710f, -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, 1.55682337f, 1.28101575f, 
-0.35396016f, 1.11443567f, 1.18218529f, -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, -1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, 
-0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 0.64210880f, 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, 
-0.21325994f, 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, -0.30697080f, 0.38139752f, 0.49113834f, -0.95485127f, -0.99908817f, 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, -0.65509188f, 0.94100451f, 
-0.57324487f, 2.19070768f, 1.24058700f, -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, -1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, 
-0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, 0.60902315f, -0.94220877f, 0.79479855f, -0.25656971f, 0.08739131f, 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, 0.38157779f, -0.67648011f, 1.09093642f, 
1.66012859f, -0.29358891f, -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, 1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, 
-0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, -1.21652532f, -0.08211990f, -1.43739200f, -1.31684434f, 0.43312529f, -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, -0.18035053f, -0.28302726f, 
-0.83813244f, 0.57593471f, -0.03807700f, 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, -1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, 
-0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, 0.92343372f, 0.26987928f, 0.46123627f, 0.22577702f, 1.26289701f, -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, -0.59345531f, 0.77925014f, 1.32338643f, 
0.14068902f, 1.19032264f, 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, -0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, 
-0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, 1.87173629f, 0.22803484f, 1.66433418f, -1.05872321f, -1.13663685f, 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, 0.68218285f, 1.98158526f, 0.47492698f, 
0.54224718f, 1.28722692f, -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, 0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, -1.16662264f, 
1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, -1.36352801f, -0.90797836f, 0.32128647f, 0.66191489f, 0.08681208f, 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, -0.79168582f, -0.69197005f, -0.77175534f, 
-0.28669405f, -0.79764080f, 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, -0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, -2.39369726f, 
-1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, -0.75534385f, 1.27303052f, 0.43453619f, 0.11296938f, 1.18182182f, 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, 
-1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, -0.87048137f, -0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, -1.34869373f, 0.00038156f, -0.92473024f, 
1.48010647f, -0.02754467f, -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, -1.41261208f, -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, -2.41602516f, 
0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, 0.11993053f, 1.81336939f, -0.06345166f, -1.49658346f, 0.07531686f, 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, 0.77326930f, 1.60129988f, -1.85124147f, 
-0.30582917f, 1.30227256f, 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 1.15034151f, -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, -1.60788095f, 
-1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, 0.75999504f, -0.73411214f, -0.05406699f, 1.35664344f, -0.25298578f, -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 
0.70015103f, 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, -0.21775827f, -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, -0.45906609f, 
-0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, 0.74389762f, 0.11971968f, -0.07381748f, 1.32652700f, 1.34079397f});
|
|
|
|
auto input2 = NDArrayFactory::create<TypeParam>('c', {3, 4, 4, 5}, {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, 0.17447931f, 0.71125424f, 0.30263851f, 
0.70675352f, 0.03215584f, 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f});
|
|
auto exp = NDArrayFactory::create<TypeParam>('c', {3, 8, 8, 16}, {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, -0.06303531f, -4.38615131f, 0.48237842f, 0.26528549f, 3.38231516f, 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, 1.13355756f, -0.07890664f, 
1.31536257f, -0.43924797f, -0.04562932f, -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 7.47784615f, 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, 
-3.90390921f, -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, -1.98357940f, -8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, -2.53411579f, -2.87311554f, 
-1.31391978f, -4.26736879f, 3.45099425f, 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, -4.32036924f, 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, 
5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, -1.64118791f, 4.52447939f, 0.04468334f, -1.48768258f, -0.69848812f, -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, -0.48799962f, -0.45379883f, 
-0.26652339f, -1.20720732f, 3.94169855f, -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 5.20375061f, 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, 
0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, 4.06516457f, 3.50772929f, 3.52762985f, 0.55643129f, 0.32272506f, -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, -2.34453487f, -1.28573155f, 
0.09990287f, 0.01828218f, -1.79960847f, -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 4.09455204f, -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 
1.86805439f, -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, 2.32898521f, -0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, 3.75122595f, 8.14969349f, 
6.80038738f, 4.68907356f, -1.40135396f, -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, -1.43354976f, 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, 
11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, 0.31757897f, -3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, 5.24342251f, -0.72513700f, 1.06113195f, 
-1.30720282f, 4.69107103f, 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, -3.88852644f, 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, 
0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, 2.87417006f, 0.97556210f, -3.75203967f, 1.55287778f, -7.43401051f, -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, -0.80560827f, 0.23925614f, 5.98795223f, 
-2.03578377f, -7.77835321f, -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 0.59795094f, -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, 
0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, -5.01056004f, 0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, 4.64244223f, 7.10090065f, 
1.01974952f, -4.44687223f, 2.99792576f, 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, -8.48874187f, 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, 
2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, 1.65194583f, 2.71262598f, -4.86145496f, 0.79726243f, -8.88541985f, 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, -0.20005548f, 3.13672113f, 0.51187158f, 
7.11563921f, 6.43059587f, 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, -0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, 
-12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, -0.23225522f, -2.40723038f, 0.18496060f, -3.70608735f, -0.19918591f, -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, 5.78582096f, -6.18445301f, 
-2.85200453f, -5.30461454f, -4.16009140f, -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, 1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, 
-1.27868056f, 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, -2.17146754f, -0.22308528f, -6.02561045f, 0.49161875f, -6.74280357f, -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, 3.53978682f, 0.66798484f});
|
|
|
|
sd::ops::deconv2d_tf op;
|
|
auto result = op.evaluate({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1});
|
|
ASSERT_EQ(Status::OK(), result.status());
|
|
|
|
auto z = result.at(0);
|
|
|
|
ASSERT_TRUE(exp.isSameShape(z));
|
|
ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) {
    // Shape-only check for dilation2d: the input/weight contents are never set here,
    // only the shape of the first output is compared against the expected array.
    // Expected spatial dims equal ceil(128 / 2) = 64 and ceil(128 / 3) = 43, which suggests the
    // second group of int args ({1,2,3,1}) acts as strides with the leading 1 selecting SAME
    // padding. NOTE(review): confirm the argument order against the dilation2d op declaration.
    auto input    = NDArrayFactory::create<double>('c', {4, 128, 128, 4});
    auto kernel   = NDArrayFactory::create<double>('c', {4, 5, 4});
    auto expected = NDArrayFactory::create<double>('c', {4, 64, 43, 4});

    sd::ops::dilation2d op;
    auto result = op.evaluate({&input, &kernel}, {}, {1, 1,5,7,1, 1,2,3,1});
    ASSERT_EQ(Status::OK(), result.status());

    ASSERT_TRUE(expected.isSameShape(result.at(0)));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
    // Smoke test: only verifies that dilation2d reports success for this configuration;
    // the produced output is not inspected.
    auto image  = NDArrayFactory::create<double>('c', {4, 26, 19, 4});
    auto filter = NDArrayFactory::create<double>('c', {11, 7, 4});

    sd::ops::dilation2d op;
    // int args: a leading mode flag (0, presumably VALID) followed by two groups of four
    // per-dimension values. NOTE(review): confirm which group is rates and which is strides.
    auto result = op.evaluate({&image, &filter}, {}, {0, 1,2,3,1, 1,3,2,1});

    ASSERT_EQ(Status::OK(), result.status());
}
|
|
|
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
    // Separable conv2d backprop (NCHW input, no bias) checked against precomputed reference
    // gradients for epsilon (dL/dInput), the depthwise weights and the pointwise weights.
    // Reference arrays are built from raw buffers plus hand-written shapeInfo descriptors of the
    // form {rank, shape..., strides..., type flags, ews, order(99 = 'c')}.
    // NOTE(review): the meaning of the trailing shapeInfo fields is inferred from their values.

    // reference gradient of the pointwise weights, stored as [10, 6, 1, 1] then permuted
    TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f};
    Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expGWP(_expGradWpB, _expGradWpS);
    expGWP.permutei({2,3,1,0});

    // reference gradient of the depthwise weights, stored as [2, 3, 5, 5] then permuted
    TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f};
    Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expGWD(_expGradWdB, _expGradWdS);
    expGWD.permutei({2,3,1,0});

    // reference epsilon (gradient w.r.t. the [2, 3, 10, 10] input)
    TypeParam _expEB[] = {5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, 137.75788f, 185.2406f, 233.5091f, 234.4682f, 189.16564f, 143.06998f, 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 
127.9629f, 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 267.38025f, 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, 300.57564f, 
227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, 324.97692f, 435.9516f, 548.25405f, 549.1908f, 442.02378f, 333.52314f, 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, 189.6277f, 381.49435f, 575.6032f, 771.9575f, 
970.5605f, 972.203f, 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f };
    Nd4jLong _expES[] = {4, 2, 3, 10, 10, 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
    NDArray expE(_expEB, _expES);

    auto input = NDArrayFactory::create<TypeParam>('c', {2, 3, 10, 10});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {2, 3, 5, 5});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {10, 6, 1, 1});
    // 10 - 5 + 1 = 6: VALID 5x5 convolution output size
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {2, 10, 6, 6});

    input.linspace(1);
    weightsD.linspace(1);
    weightsP.linspace(1);
    epsilonNext.linspace(1);

    // permute weights after filling so the reference values line up with this element order
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    input.applyScalar(scalar::Divide, 100.0, input);
    weightsD.applyScalar(scalar::Divide, 100.0, weightsD);
    weightsP.applyScalar(scalar::Divide, 100.0, weightsP);
    epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext);

    sd::ops::sconv2d_bp op;
    // int args: kH,kW, sH,sW, pH,pW, dH,dW, paddingMode(0 = VALID)
    auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {});

    // fail cleanly if the op itself failed, before any output is dereferenced
    ASSERT_EQ(Status::OK(), resultBP.status());
    ASSERT_EQ(3, resultBP.size());

    auto _epsilon = resultBP.at(0);
    auto _gradWD = resultBP.at(1);
    auto _gradWP = resultBP.at(2);

    ASSERT_TRUE(_gradWP->isSameShape(&expGWP));
    ASSERT_TRUE(_gradWP->isSameShape(&weightsP));
    ASSERT_TRUE(_gradWP->equalsTo(&expGWP));

    ASSERT_TRUE(_gradWD->isSameShape(&expGWD));
    ASSERT_TRUE(_gradWD->isSameShape(&weightsD));
    ASSERT_TRUE(_gradWD->equalsTo(&expGWD));

    ASSERT_TRUE(_epsilon->isSameShape(&input));
    ASSERT_TRUE(_epsilon->isSameShape(&expE));
    ASSERT_TRUE(_epsilon->equalsTo(&expE));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_2) {
    // Determinism check for sconv2d_bp: run once to capture reference gradients, then re-run
    // the op ten more times into the same output arrays and require identical results.

    int bS=3, iH=16,iW=16, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2;
    int oH=16,oW=16;
    int oC=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat = 0;              // 1-NHWC, 0-NCHW

    const auto dtype = typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE;

    NDArray input('c', {bS, iC, iH, iW}, dtype);
    NDArray gradO('c', {bS, oC, oH, oW}, dtype);
    NDArray weightsDepth('c', {kH, kW, iC, mC}, dtype);
    NDArray weightsPoint('f', {1, 1, iC*mC, oC}, dtype);
    NDArray bias('c', {1,oC}, {0.5, 0.5}, dtype);

    // output arrays shaped like their corresponding inputs
    NDArray gradI(&input);
    NDArray gradWD(&weightsDepth);
    NDArray gradWP(&weightsPoint);
    NDArray gradB(&bias);

    input = 2.;
    weightsDepth.linspace(0.1, 0.1);
    weightsPoint.linspace(0.15, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::sconv2d_bp op;
    Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias},
                                   {&gradI, &gradWD, &gradWP, &gradB},
                                   {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);

    // snapshot of the first run's gradients
    NDArray expGradI = gradI;
    NDArray expGradWD = gradWD;
    NDArray expGradWP = gradWP;
    NDArray expGradB = gradB;

    for (int i = 0; i < 10; ++i) {
        status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias},
                            {&gradI, &gradWD, &gradWP, &gradB},
                            {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
        ASSERT_EQ(Status::OK(), status);

        ASSERT_TRUE(expGradI.equalsTo(gradI));
        ASSERT_TRUE(expGradWD.equalsTo(gradWD));
        ASSERT_TRUE(expGradWP.equalsTo(gradWP));
        // compare against the freshly computed bias gradient (previously compared expGradB to itself)
        ASSERT_TRUE(expGradB.equalsTo(gradB));
    }
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_3) {
    // Shape check for separable conv2d backprop with 2x2 kernels, dilation 2 and no bias input.

    auto input = NDArrayFactory::create<TypeParam>('c', {3, 3, 16, 16});
    auto weightsD = NDArrayFactory::create<TypeParam>('c', {1, 3, 2, 2});
    auto weightsP = NDArrayFactory::create<TypeParam>('c', {2, 3, 1, 1});

    // after permutation: weightsD -> {2, 2, 3, 1} (kH, kW, iC, mC), weightsP -> {1, 1, 3, 2}
    weightsD.permutei({2,3,1,0});
    weightsP.permutei({2,3,1,0});

    // dilated 2x2 kernel spans 3 pixels, so a VALID pass over 16 gives 16 - 3 + 1 = 14
    auto epsilonNext = NDArrayFactory::create<TypeParam>('c', {3, 2, 14, 14});
    auto epsilon = NDArrayFactory::create<TypeParam>('c', {3, 3, 16, 16});

    sd::ops::sconv2d_bp op;
    // int args: kH,kW, sH,sW, pH,pW, dH,dW, paddingMode(0 = VALID)
    auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0});
    // fail cleanly if the op itself failed, before any output is dereferenced
    ASSERT_EQ(Status::OK(), result.status());

    auto eps = result.at(0);
    auto gWD = result.at(1);
    auto gWP = result.at(2);

    ASSERT_TRUE(epsilon.isSameShape(eps));
    // weight gradients must match their (permuted) weights, as in sconv2d_bp_1
    ASSERT_TRUE(gWD->isSameShape(&weightsD));
    ASSERT_TRUE(gWP->isSameShape(&weightsP));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) {
    // Depthwise-only sconv2d backprop (no pointwise weights, bias given) in NHWC layout with
    // SAME padding, checked against precomputed input and depthwise-weight gradients.

    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=3;
    int oC=iC*mC;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat = 1;              // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weightsDepth = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3,4});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f,
                                                  1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f,10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f,10.38f, 12.028f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC},{19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f});

    input = 2.;
    weightsDepth.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::sconv2d_bp op;
    auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    // check the status before touching any output (was previously checked after at(0)/at(1))
    ASSERT_EQ(Status::OK(), results.status());

    auto* gradI = results.at(0);
    auto* gradWD = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradWD));
    ASSERT_TRUE(expGradW.equalsTo(gradWD));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, sconv2d_bp_5) {
    // Smoke test: sconv2d_bp with pointwise weights and bias, writing into caller-supplied
    // outputs (the depthwise-weight gradient is deliberately allocated in 'f' order).
    // Only the returned status is checked.

    int bS=1, iH=8,iW=8, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=8,oW=8;
    int oC=2;                        // iC*mC if weightsPoint = nullptr
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat = 0;              // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<double>('c', {bS, oC, oH, oW});
    auto weightsDepth = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    auto weightsPoint = NDArrayFactory::create<double>('c', {1, 1, iC*mC, oC});
    auto bias = NDArrayFactory::create<double>('c', {1,oC}, {1,2});

    auto gradI = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto gradWD = NDArrayFactory::create<double>('f', {kH, kW, iC, mC});
    auto gradWP = NDArrayFactory::create<double>('c', {1, 1, iC*mC, oC});
    auto gradB = NDArrayFactory::create<double>('c', {1,oC}, {1,2});

    input = 2.;
    weightsDepth.linspace(0.1, 0.1);
    // fill the pointwise weights; this was previously a second weightsDepth.linspace call,
    // which overwrote the depthwise weights and left weightsPoint uninitialised
    weightsPoint.linspace(-0.5, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::sconv2d_bp op;
    auto status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, {&gradI, &gradWD, &gradWP, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});
    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, im2col_bp_1) {
    // Smoke test for im2col_bp: a column gradient shaped [bS, iC, kH, kW, oH, oW] is folded
    // back into an image-shaped gradient [bS, iC, iH, iW]. Only the returned status is checked.
    int bS = 3, iC = 6, iH = 12, iW = 12;
    int kH = 2, kW = 2;
    int sH = 1, sW = 1;
    int pH = 0, pW = 0;
    int dH = 1, dW = 1;
    int oH = 12, oW = 12;
    // trailing int arg; with oH == iH for a 2x2 kernel this presumably selects SAME padding
    int sameMode = 1;

    NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE);
    NDArray columnsGrad('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE);
    NDArray imageGrad('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE);   // output

    sd::ops::im2col_bp op;
    Nd4jStatus rc = op.execute({&image, &columnsGrad}, {&imageGrad}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, sameMode}, {});

    ASSERT_EQ(ND4J_STATUS_OK, rc);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test1) {
    // deconv3d in NDHWC layout with VALID padding: a [bS, oD, oH, oW, oC] input is expanded to
    // [bS, iD, iH, iW, iC] (3 + 2 - 1 = 4 per spatial dim) and compared with reference values.

    int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat = 1;              // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
    auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});
    auto exp = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45,
                                                 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 ,
                                                 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 ,
                                                 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05,
                                                 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45,
                                                 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 ,
                                                 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 ,
                                                 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05});
    input = 0.5;
    weights.linspace(0.1, 0.1);

    sd::ops::deconv3d op;
    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});
    // check the status before touching the output (was previously checked after at(0))
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test2) {

    // deconv3d, SAME padding, stride 1, NDHWC data: spatial 4x4x4 -> 4x4x4.
    // In this test the o* dims/channels describe the deconv INPUT and the i* ones its OUTPUT.
    int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<double>('c', {bS, oD, oH, oW, oC});
    auto weights = NDArrayFactory::create<double>('c', {kD, kH, kW, iC, oC});

    auto exp = NDArrayFactory::create<double>('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,
                                                                           0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,
                                                                           4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 });

    input = 0.5;
    weights.linspace(0.1, 0.1);

    sd::ops::deconv3d op;
    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before fetching the output so that a failing op is reported as such,
    // instead of surfacing later while reading a possibly missing result.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test3) {

    // deconv3d, VALID padding, stride 1, kernel 2, NCDHW data: spatial 3x3x3 -> 4x4x4.
    // In this test the o* dims/channels describe the deconv INPUT and the i* ones its OUTPUT.
    int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
    auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});

    auto exp = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW}, {2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6,
                                                                           5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6,
                                                                           3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. ,
                                                                           8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8,
                                                                           2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6,
                                                                           5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6,
                                                                           3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. ,
                                                                           8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8});

    input = 0.5;

    // Weights are filled in [oC, iC, kD, kH, kW] order, then permuted in place to
    // [kD, kH, kW, iC, oC] (permutation {2, 3, 4, 1, 0}).
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});

    sd::ops::deconv3d op;
    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before fetching the output so that a failing op is reported as such,
    // instead of surfacing later while reading a possibly missing result.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test4) {

    // deconv3d, VALID mode with explicit padding 1 in every spatial dim, NCDHW data:
    // spatial 3x3x3 -> 2x2x2.
    // In this test the o* dims/channels describe the deconv INPUT and the i* ones its OUTPUT.
    int bS=2, iD=2,iH=2,iW=2, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
    int oD=3,oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<double>('c', {bS, oC, oD, oH, oW});
    auto weights = NDArrayFactory::create<double>('c', {oC, iC, kD, kH, kW});

    auto exp = NDArrayFactory::create<double>('c', {bS, iC, iD, iH, iW}, {24.6, 24.6,24.6, 24.6,24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2,24.6, 24.6,24.6, 24.6,
                                                                           24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2});

    input = 0.5;

    // Weights are filled in [oC, iC, kD, kH, kW] order, then permuted in place to
    // [kD, kH, kW, iC, oC] (permutation {2, 3, 4, 1, 0}).
    weights.linspace(0.1, 0.1);
    weights.permutei({2, 3, 4, 1, 0});

    sd::ops::deconv3d op;
    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before fetching the output so that a failing op is reported as such,
    // instead of surfacing later while reading a possibly missing result.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test5) {

    // deconv3d with dilation 2, VALID padding, stride 1, NDHWC data: spatial 3x3x3 -> 5x5x5.
    int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
    int iD=3,iH=3,iW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
    auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, oC, iC});

    // Only input and weights are passed to the op: the expected values carry no bias term
    // (e.g. exp[0] = -10*0.1 + -9.5*0.2 = -2.9), so no bias array is created here.
    auto exp = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f,
                                                                          -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f,
                                                                          -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f,
                                                                          -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f,
                                                                          2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f,
                                                                          15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f,
                                                                          45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f,
                                                                          -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f,
                                                                          -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f,
                                                                          14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f,
                                                                          8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f,
                                                                          83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f,
                                                                          43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f,
                                                                          119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f,
                                                                          134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f,
                                                                          93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f});

    input.linspace(-10, 0.5);
    weights.linspace(0.1, 0.1);

    sd::ops::deconv3d op;
    auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat});

    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test6) {

    // deconv3d, VALID padding, stride 1, NDHWC data, weights supplied in format 1
    // ([iC, oC, kD, kH, kW]); iC=10 input channels -> oC=5 output channels,
    // spatial 3x3x3 -> 4x4x4.
    const int bS=2, oD=4,oH=4,oW=4, oC=5,iC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    const int iD=3,iH=3,iW=3;
    const int paddingMode = 0;   // 1-SAME, 0-VALID
    const int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW
    const int wFormat     = 1;   // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

    // Input is an arithmetic sequence starting at -27 with step 0.1.
    NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32);
    input.linspace(-27, 0.1);

    NDArray weights('c', {iC, oC, kD, kH, kW}, {20., 15., 10., 5., 0., -5., -10., -15., 19., 14., 9., 4., -1., -6., -11., -16., 18., 13., 8., 3., -2., -7., -12., -17.,
        17., 12., 7., 2., -3., -8., -13., -18., 16., 11., 6., 1., -4., -9., -14., -19., 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, 18.9, 13.9, 8.9, 3.9, -1.1, -6.1,
        -11.1, -16.1, 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1,
        19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2,
        -17.200001, 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, 16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oD, oH, oW, oC}, {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, -8859.700195, -8338.700195, -7817.700195,
        -7296.700195, -6775.700195, -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, -3572.850098, -3327.349854, -3081.850098, -2836.350098,
        -2590.850098, -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902,
        -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, -6268.200195,
        -5827.200195, -5386.200195, -4945.200195, -4504.200195, -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, -9558.400391, -8736.400391,
        -7914.400391, -7092.399902, -6270.400391, -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, 1169.199463, 2693.200195, 4217.199707, 5741.201172, 7265.203125, 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, 540.599976, 882.600220, 1224.599854, 1566.599854,
        1908.600220, 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, 2429.600098, 2691.600098, 2953.599609, 3215.599609, 3477.599609, 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, 1588.800049, 1489.800049, 1390.800049, 1291.800049,
        1192.799927, 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, 5433.199707, 4877.200195, 4321.200195, 3765.199707, 3209.199951, 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635,
        5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, -2026.200073, -2425.200195, -2824.200195, -3223.199707, -3622.200195, -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, sd::DataType::FLOAT32);

    // All configuration goes in as integer args; wFormat selects the [iC, oC, kD, kH, kW] layout.
    sd::ops::deconv3d op;
    auto result = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat});
    ASSERT_EQ(Status::OK(), result.status());

    auto* actual = result.at(0);

    // Values are float32, so compare with an absolute tolerance of 1e-3.
    ASSERT_TRUE(expOutput.isSameShape(actual));
    ASSERT_TRUE(expOutput.equalsTo(actual, 1e-3));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_test7) {
|
|
|
|
int bS=2, oD=4,oH=4,oW=4, iC=5,oC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=0,pW=0, dD=1,dH=1,dW=1;
|
|
int iD=4,iH=4,iW=4;
|
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
|
int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]
|
|
|
|
NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32);
|
|
NDArray weights('c', {iC, kD, kH, kW, oC}, {20., 19.5, 19., 18.5, 18., 17.5, 17., 16.5, 16., 15.5, 15., 14.5, 14., 13.5, 13., 12.5, 12., 11.5, 11., 10.5, 10.,
|
|
9.5, 9., 8.5, 8., 7.5, 7., 6.5, 6., 5.5, 5., 4.5, 4., 3.5, 3., 2.5, 2., 1.5, 1., 0.5, 0., -0.5, -1., -1.5, -2., -2.5, -3., -3.5, -4., -4.5, -5., -5.5, -6.,
|
|
-6.5, -7., -7.5, -8., -8.5, -9., -9.5, -10., -10.5, -11., -11.5, -12., -12.5, -13., -13.5, -14., -14.5, -15., -15.5, -16., -16.5, -17., -17.5, -18., -18.5,
|
|
-19., -19.5, 19.9, 19.4, 18.9, 18.4, 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, 8.9, 8.4, 7.9,
|
|
7.4, 6.9, 6.4, 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, 2.6, 2.1, 1.6, 1.1, 0.6, 0.1, -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, -18.4, -18.9, -19.4, -19.9}, sd::DataType::FLOAT32);
|
|
NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, -4664.800293, -4640.199707, -4615.600098,
|
|
-2755.599854, -4566.400391, -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513.,
|
|
-1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, -1273.200195, -1263.999878,
|
|
-1579.200073, -2308.599854, -2294., -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, -1400.800049, -1172., -1162.800049, -1153.600098,
|
|
-1362.399902, -1135.199951, -1126., -1116.799805, -1422.400024, -2075., -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, -1043.200195,
|
|
-1247.199951, -1024.800049, -1015.599976, -1006.400146, -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, -957., -949.800049, -1342., -935.399902, -928.199951, -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, -849., -841.800171, -834.599976, -1204.400024, -820.199707, -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, -1101.199951, -733.800049, -726.599854, -719.400024, -1066.800049, -705., -697.800171, -690.599976, -1811.199951, -3133., -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, -2514.399902, -4105., -4082.400146, -4059.800293, -1552., -2175.200195, -2162.599854, -2150., -1228.800049, -630., -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, -598.800049, -1167.999878, -588.400391, -583.199951, -578., -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, -520.800049, -515.599976, -1046.400146, -505.199829, -500., -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, 
-1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, -260.599854, -1020.400146, -254.199829, -251., -247.799927, -994., -241.400269, -238.200073, -234.999878, -1327.200073, -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, -148.599976, -145.400024, -782.799927, -139., -135.799805, -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, -1786.999878, -864.800049, 95.999878, 97.200073, 98.400024, -842.39978, 100.799927, 102.000244, 103.200439, -820., 105.599609, 106.800171, 108., -1243.199951, -1638.599854, -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, 117.60022, -752.799805, 120., 121.200073, 122.400024, -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, 141.599731, -640.799988, 144., 145.200195, 146.400146, -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, -590.799927, 443., 442.200073, 441.399658, -572.39978, 439.799927, 439.000122, 438.200073, 
-1042.399902, -1317.499756, -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, -1184., -1441.199951, -1432.599854, -1424., -500.799927, 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, 808.000244, 805.200073, -472., 799.60022, 796.799683, 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, 766., 763.200317, 760.400024, -414.400146, 754.800049, 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, -1148.799927, -385.600098, 732.400024, 729.599976, 726.799927, -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, -2011.599976, -3082.400146, -3065.800049, -3049.199951, 
-1968.400024, -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, -2818.900146, -2803.300049, -907.999939, -890.700012, -885.099915, -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, 3453.200195, 1400., 2129.800049, 2144.400146, 2159., 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, 1624.800171, 1634., 1643.199951, 1556., 1661.600098, 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, 1671.199829, 1771.999878, 1781.200073, 1790.400024, 
1709.60022, 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, 1347., 1354.199951, 1410., 1368.600098, 1375.800171, 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, 2455.699951, 1616.399902, 1541.400146, 1548.600098, 1555.799683, 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, 1901.599976, 3127., 3149.599854, 3172.200195, 1264., 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, 2988.5, 
3010.099854, 3031.700195, 1196.000122, 1720.300049, 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, 763.400024, 1091.599976, 769.800171, 773., 776.199951, 1118., 782.599976, 785.800049, 789., 1328.800049, 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, 811.400024, 814.60022, 1197.199951, 821., 824.199951, 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, 875.400146, 878.599854, 1329.199951, 885., 888.199951, 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, 1605., 927.200012, 480., 481.199951, 482.400146, 949.599976, 484.800171, 486., 487.200073, 971.999878, 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, 501.60022, 1039.199951, 504., 505.199951, 506.400146, 1061.599976, 508.799927, 510., 511.200195, 1377.599976, 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, 525.600098, 1151.199829, 528., 529.199829, 530.400146, 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, 199.800171, 199., 198.200195, 826., 196.599731, 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, 881.199951, 187., 186.199829, 185.400024, 899.60022, 183.800171, 183., 182.200073, 1293.599976, 1754.499878, 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, 175.800049, 
954.799805, 174.199951, 173.400024, 172.599854, 973.200073, 171., 170.200073, 169.400146, 1068.800049, 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, -74., -76.799805, -79.599854, 665.600098, -85.199829, -87.999756, -90.799805, 680., -96.400024, -99.199829, -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, -130., -132.800171, -135.599976, 737.599976, -141.200073, -144., -146.799805, 1209.599976, 1586., 1594.600098, 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, 1775.550049, 1789.349976, 1466.200073, 2293.700195, 2311.300049, 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, -379.799805, 534., -389.400146, -394.19989, -398.999817, 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, 1044.400146, 1051., 375.199951, -627.999756, -634.800171, -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, 388.000061, 
-682.400146, -689.199951, -695.999756, 948.799988, 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, 239.599976, -940.199707, -948.999817, -957.800171, 242., -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, 996.499878, 1002.10022, 246.800049, -1045.799927, -1054.599854, -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, sd::DataType::FLOAT32);
|
|
|
|
input.linspace(-32, 0.1);
|
|
|
|
sd::ops::deconv3d op;
|
|
auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat});
|
|
ASSERT_EQ(Status::OK(), results.status());
|
|
|
|
auto output = results.at(0);
|
|
|
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test1) {
    // deconv3d_bp, VALID padding, NDHWC layout, with bias.
    // The forward op maps "input" [bS, oD, oH, oW, oC] to an output shaped like gradO [bS, iD, iH, iW, iC].

    int bS=1, iD=3,iH=3,iW=3,  iC=1,oC=2,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
    auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
    auto bias    = NDArrayFactory::create<float>('c', {iC});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});

    NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, sd::DataType::FLOAT32);
    NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, sd::DataType::FLOAT32);
    NDArray expGradB('c', {iC}, std::vector<double>{364.5}, sd::DataType::FLOAT32);

    input = 0.5;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.5);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before touching the outputs: a failed op may not hold valid results.
    ASSERT_EQ(Status::OK(), results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);
    auto gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test2) {
    // deconv3d_bp, SAME padding, NDHWC layout, no bias input.

    int bS=1, iD=2,iH=2,iW=2,  iC=1,oC=2,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 1;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC});
    auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});

    NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, sd::DataType::FLOAT32);
    NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, sd::DataType::FLOAT32);

    input = 0.5;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.5);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before touching the outputs: a failed op may not hold valid results.
    ASSERT_EQ(Status::OK(), results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test3) {
    // deconv3d_bp, VALID padding, NCDHW layout, explicit (non-linspace) weights.

    int bS=1, iD=3,iH=3,iW=3,  iC=1,oC=2,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
    auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});

    // gradI has the same (NCDHW) shape as input. Since oC == oD == oH == oW == 2,
    // this is numerically the same shape as the NDHWC form previously written here.
    NDArray expGradI('c', {bS, oC, oD, oH, oW}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, sd::DataType::FLOAT32);
    NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, sd::DataType::FLOAT32);

    input = 0.5;
    gradO.linspace(0.5);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before touching the outputs: a failed op may not hold valid results.
    ASSERT_EQ(Status::OK(), results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test4) {
    // deconv3d_bp, explicit padding of 1 in every spatial dim, NCDHW layout.

    int bS=1, iD=2,iH=2,iW=2,  iC=1,oC=2,  kD=2,kH=2,kW=2,  sD=1,sH=1,sW=1,  pD=1,pH=1,pW=1,  dD=1,dH=1,dW=1;
    int oD=3,oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW

    auto input   = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
    auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});

    NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, sd::DataType::FLOAT32);
    NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, sd::DataType::FLOAT32);

    input = 0.5;
    gradO.linspace(0.5);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {});

    // Check the status before touching the outputs: a failed op may not hold valid results.
    ASSERT_EQ(Status::OK(), results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test5) {
    // deconv3d_bp, SAME padding, NCDHW layout, weights in format 1 ([iC, oC, kD, kH, kW]), with bias.

    int bS=2, iD=4,iH=4,iW=4,  iC=3,oC=2,  kD=2,kH=1,kW=1,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 1;             // 1-SAME, 0-VALID
    int dataFormat  = 0;             // 1-NDHWC, 0-NCDHW
    int wFormat     = 1;             // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

    NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32);
    NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32);
    NDArray weights('c',{iC, oC, kD, kH, kW}, {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, sd::DataType::FLOAT32);
    NDArray gradO('c', {bS, oC, oD, oH, oW},sd::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iC, iD, iH, iW}, {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, 9.552,
        9.540001, 9.528, 9.516, 9.504001, 9.492, 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, 9.335999,
        9.324001, 9.312, 9.300001, 9.288001, 9.276001, 9.264, 9.252001, 9.24, 9.228001, 9.216, 9.204, 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, 13.152, 13.134001,
        13.116, 13.098, 13.080001, 13.062, 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, 3.608, 3.604,
        3.6, 3.596, 3.592, 3.588, 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, 3., 2.996, 2.992, 2.987999, 2.984001, 2.98, 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, 8.284, 8.27, -1.952, -1.948, -1.944, -1.94,
        -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {iC, oC, kD, kH, kW}, {-73678.695312, -59907.972656, -67739.515625, -54962.082031, -15966.075195, -17115.042969, -15269.777344, -16101.275391, 41746.566406, 25677.917969, 37200.003906, 22759.517578}, sd::DataType::FLOAT32);
    NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, sd::DataType::FLOAT32);

    input.linspace(100., -0.5);
    gradO.linspace(-16, 0.02);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);
    auto gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, deconv3d_bp_test6) {
    // deconv3d_bp, VALID padding, NDHWC layout, weights in format 2 ([iC, kD, kH, kW, oC]), with bias.

    int bS=2, iD=4,iH=4,iW=4,  iC=3,oC=2,  kD=2,kH=1,kW=1,  sD=1,sH=1,sW=1,  pD=0,pH=0,pW=0,  dD=1,dH=1,dW=1;
    int oD=5,oH=4,oW=4;
    int paddingMode = 0;             // 1-SAME, 0-VALID
    int dataFormat  = 1;             // 1-NDHWC, 0-NCDHW
    int wFormat     = 2;             // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC]

    NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32);
    NDArray weights('c',{iC, kD, kH, kW, oC}, {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, sd::DataType::FLOAT32);
    NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iD, iH, iW, iC}, {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996,
        0.462, -0.072, 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, 0.912, 0.434,
        -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, 0.84, 0.41, -0.02, 0.828, 0.406, -0.016,
        0.816, 0.402, -0.012, 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, 0.732, 0.374,
        0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, 0.036, 0.142, 0.248, 0.024, 0.138, 0.252, 0.012, 0.134, 0.256, 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07,
        0.46, -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {iC, kD, kH, kW, oC}, {-6328.958984, -6322.880371, -6134.400879, -6128.319824, -6318.079590, -6312.640137, -6144.000000, -6138.560547, -6307.202637, -6302.399414, -6153.599609, -6148.799316}, sd::DataType::FLOAT32);
    NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32);

    input.linspace(100., -0.5);
    gradO.linspace(-1.6, 0.01);

    sd::ops::deconv3d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto gradI = results.at(0);
    auto gradW = results.at(1);
    auto gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_1) {
    // Shape-only check of maxpool2d in VALID mode, driven through a Context and VariableSpace.
    // Uses the fixture's geometry members (bS, iD, iH, iW, kH, ...).

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});

    // Stack objects instead of new/delete: they are destroyed even when an ASSERT_* returns early,
    // and in reverse order (block first, then the variableSpace it points to).
    // x is handed to variableSpace, which (as in the heap-allocated form) is expected to free it.
    VariableSpace variableSpace;
    variableSpace.putVariable(-1, x);

    Context block(1, &variableSpace, false);
    block.fillInputs({-1});
    std::vector<int>* argI = block.getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0};     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

    sd::ops::maxpool2d pooling;
    Nd4jStatus status = pooling.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace.getVariable(block.getNodeId())->getNDArray();
    ASSERT_TRUE(exp.isSameShape(result));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_2) {
    // Shape-only check of maxpool2d in VALID mode with a 5x5 kernel on 28x28 input.
    // Several of these locals shadow fixture members of the same name.

    const int bS = 2;
    const int iD = 1;
    const int iH = 28;
    const int iW = 28;
    const int kH = 5;
    const int kW = 5;
    const int sH = 1;
    const int sW = 1;
    const int pH = 0;
    const int pW = 0;
    const int dH = 1;
    const int dW = 1;
    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1;     // output height (VALID mode)
    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1;     // output width  (VALID mode)

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});

    // Stack objects instead of new/delete: they are destroyed even when an ASSERT_* returns early,
    // and in reverse order (block first, then the variableSpace it points to).
    // x is handed to variableSpace, which (as in the heap-allocated form) is expected to free it.
    VariableSpace variableSpace;
    variableSpace.putVariable(-1, x);

    Context block(1, &variableSpace, false);
    block.fillInputs({-1});
    std::vector<int>* argI = block.getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0};     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

    sd::ops::maxpool2d pooling;
    Nd4jStatus status = pooling.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace.getVariable(block.getNodeId())->getNDArray();
    ASSERT_TRUE(exp.isSameShape(result));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_3) {
    // Shape-only check of maxpool2d in SAME mode with a 5x5 kernel on 28x28 input.
    // Several of these locals shadow fixture members of the same name.

    const int bS = 2;
    const int iD = 1;
    const int iH = 28;
    const int iW = 28;
    const int kH = 5;
    const int kW = 5;
    const int sH = 1;
    const int sW = 1;
    const int pH = 0;
    const int pW = 0;
    const int dH = 1;
    const int dW = 1;
    // SAME mode: output extent is ceil(input / stride)
    const int oH = static_cast<int>(sd::math::nd4j_ceil<float, int>(iH * 1.f / sH));
    const int oW = static_cast<int>(sd::math::nd4j_ceil<float, int>(iW * 1.f / sW));

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});

    // Stack objects instead of new/delete: they are destroyed even when an ASSERT_* returns early,
    // and in reverse order (block first, then the variableSpace it points to).
    // x is handed to variableSpace, which (as in the heap-allocated form) is expected to free it.
    VariableSpace variableSpace;
    variableSpace.putVariable(-1, x);

    Context block(1, &variableSpace, false);
    block.fillInputs({-1});
    std::vector<int>* argI = block.getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1};     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

    sd::ops::maxpool2d pooling;
    Nd4jStatus status = pooling.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace.getVariable(block.getNodeId())->getNDArray();
    ASSERT_TRUE(exp.isSameShape(result));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_4) {
    // Shape-only check of maxpool2d in VALID mode with a 3x3 kernel on 24x24 input.
    // Several of these locals shadow fixture members of the same name.

    const int bS = 2;
    const int iD = 1;
    const int iH = 24;
    const int iW = 24;
    const int kH = 3;
    const int kW = 3;
    const int sH = 1;
    const int sW = 1;
    const int pH = 0;
    const int pW = 0;
    const int dH = 1;
    const int dW = 1;
    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1;     // output height (VALID mode)
    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1;     // output width  (VALID mode)

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});

    // Stack objects instead of new/delete: they are destroyed even when an ASSERT_* returns early,
    // and in reverse order (block first, then the variableSpace it points to).
    // x is handed to variableSpace, which (as in the heap-allocated form) is expected to free it.
    VariableSpace variableSpace;
    variableSpace.putVariable(-1, x);

    Context block(1, &variableSpace, false);
    block.fillInputs({-1});
    std::vector<int>* argI = block.getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0};     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

    sd::ops::maxpool2d pooling;
    Nd4jStatus status = pooling.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace.getVariable(block.getNodeId())->getNDArray();
    ASSERT_TRUE(exp.isSameShape(result));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_5) {
    // Shape-only check of maxpool2d in SAME mode with a 3x3 kernel on 24x24 input.
    // Several of these locals shadow fixture members of the same name.

    const int bS = 2;
    const int iD = 1;
    const int iH = 24;
    const int iW = 24;
    const int kH = 3;
    const int kW = 3;
    const int sH = 1;
    const int sW = 1;
    const int pH = 0;
    const int pW = 0;
    const int dH = 1;
    const int dW = 1;
    // SAME mode: output extent is ceil(input / stride)
    const int oH = static_cast<int>(sd::math::nd4j_ceil<float, int>(iH * 1.f / sH));
    const int oW = static_cast<int>(sd::math::nd4j_ceil<float, int>(iW * 1.f / sW));

    auto x   = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});

    // Stack objects instead of new/delete: they are destroyed even when an ASSERT_* returns early,
    // and in reverse order (block first, then the variableSpace it points to).
    // x is handed to variableSpace, which (as in the heap-allocated form) is expected to free it.
    VariableSpace variableSpace;
    variableSpace.putVariable(-1, x);

    Context block(1, &variableSpace, false);
    block.fillInputs({-1});
    std::vector<int>* argI = block.getIArguments();
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1};     // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;

    sd::ops::maxpool2d pooling;
    Nd4jStatus status = pooling.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto result = variableSpace.getVariable(block.getNodeId())->getNDArray();
    ASSERT_TRUE(exp.isSameShape(result));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
    // Input [2, 4, 4, 2] filled with 1, 2, ..., 64 and pooled in NHWC layout (last iArg = 1).
    auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
    x.linspace(1);

    // Each 2x2/stride-2 window's maximum is its bottom-right element, because values grow along every axis.
    auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f});

    // iArgs: kH,kW, sH,sW, pH,pW, dH,dW, same mode (1 here), arg 9 (1), data format (1 - NHWC)
    sd::ops::maxpool2d op;
    auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
    // Same data as maxpool2d_6, but in VALID mode: 4 divides evenly by the 2x2/stride-2 window,
    // so the result is identical.
    auto x = NDArrayFactory::create<TypeParam>('c', {2, 4, 4, 2});
    x.linspace(1);

    auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f});

    // iArgs: kH,kW, sH,sW, pH,pW, dH,dW, same mode (0 here), arg 9 (1), data format (1 - NHWC)
    sd::ops::maxpool2d op;
    auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
    // 2x2 max pooling, stride 2, VALID mode over a {2,2,5,5} input holding 1..100.
    // With a 5-wide axis, VALID drops the last row/column, giving a 2x2 output per channel
    // (e.g. first channel windows -> 7, 9, 17, 19).
    auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 5, 5});
    input.linspace(1);

    auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2},
        {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f});

    sd::ops::maxpool2d pooling;
    // kH,kW, sH,sW, pH,pW, dH,dW, sameMode, then two trailing flags
    auto results = pooling.evaluate({&input}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0});
    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) {
    // Shape-only check: 2x2 maxpool, stride 1, VALID, on a {3,3,28,28} NCHW input -> {3,3,27,27}.
    int bS = 3;                 // batch size (number of samples)
    int iC = 3;                 // input channels
    int iH = 28, iW = 28;       // input height/width
    int kH = 2, kW = 2;         // kernel (filter) height/width
    int sH = 1, sW = 1;         // stride height/width
    int pH = 0, pW = 0;         // padding height/width
    int dH = 1, dW = 1;         // dilation height/width

    int oH = 27, oW = 27;       // output height/width

    int isSameMode = 0;         // 1-SAME, 0-VALID

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});

    sd::ops::maxpool2d op;
    auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0});

    // Verify the op succeeded before touching its output, so a failure is reported
    // as a status mismatch instead of surfacing through results.at(0).
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(output->isSameShape({bS, iC, oH, oW}));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
    // maxpool2d on a fixed NCHW input {1,3,4,4}: 2x2 kernel, stride 1, VALID -> {1,3,3,3}.
    int bS=1, iH=4,iW=4, iC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=3,oW=3;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f,
        0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f,
        0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f,
        0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f,
        0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f,
        0.85722464f, 0.85722464f, 0.85019743f});

    sd::ops::maxpool2d op;
    // dataFormat used to be declared but never passed; forward it in the trailing slot,
    // following the (..., mode, 1, layout) argument layout of the other maxpool2d tests here.
    auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto* output = results.at(0);
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_11) {
    // SAME-mode maxpool2d (2x2 kernel, stride 1, pad 1, dilation 2) on a {1,1,4,5} NCHW input.
    // SAME padding with stride 1 is expected to preserve H and W, so the output should have
    // the input's shape.
    NDArray input('c', {1,1,4,5}, sd::DataType::FLOAT32);
    NDArray z('c', {1,1,4,5}, sd::DataType::FLOAT32);   // expected output shape

    input.linspace(1.);

    sd::ops::maxpool2d op;
    auto results = op.evaluate({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0});

    ASSERT_EQ(Status::OK(), results.status());
    // z was previously allocated but never used; use it as the shape reference
    // so the test checks more than the status code.
    ASSERT_TRUE(z.isSameShape(results.at(0)));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) {
    // avgpool3dnew forward: NCDHW input, VALID mode, no explicit padding.
    // iArgs: kernel D/H/W, stride D/H/W, pad D/H/W, dilation D/H/W, paddingMode,
    // an extra pooling parameter (index 13), dataFormat.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f,
        82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f,
        154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f});
    input.linspace(1.);

    sd::ops::avgpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check the op status before reading its output, so a failure is reported
    // as a status mismatch rather than via results.at(0).
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) {
    // avgpool3dnew forward: NDHWC input, SAME mode (output spatial size equals input),
    // extra pooling parameter (index 13) set to 0.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f,
        61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f,
        79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f,
        133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f,
        169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f,
        187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f});
    input.linspace(1.);

    sd::ops::avgpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) {
    // avgpool3dnew forward: NDHWC input, VALID mode, no explicit padding.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f,
        74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f,
        173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f});
    input.linspace(1.);

    sd::ops::avgpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) {
    // avgpool3dnew forward: NCDHW input, VALID mode with explicit padding of 1 on every axis,
    // which grows the output to 4x4x4.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f,
        13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f,
        4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f,
        10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f,
        34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f,
        21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f,
        19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f,
        58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f,
        16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f,
        57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f,
        82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f,
        45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f,
        37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f,
        52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f,
        57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f,
        46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f,
        193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f,
        68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f});
    input.linspace(1.);

    sd::ops::avgpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) {
    // maxpool3dnew forward: NCDHW input, VALID mode, no explicit padding.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f,
        128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f});
    input.linspace(1.);

    sd::ops::maxpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) {
    // maxpool3dnew forward: NDHWC input, SAME mode (output spatial size equals input).
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
        85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f,
        85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f,
        157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f,
        193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f,
        193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f});
    input.linspace(1.);

    sd::ops::maxpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) {
    // maxpool3dnew forward: NDHWC input, VALID mode, no explicit padding.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f,
        166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f});
    input.linspace(1.);

    sd::ops::maxpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) {
    // maxpool3dnew forward: NCDHW input, VALID mode with explicit padding of 1 on every axis,
    // which grows the output to 4x4x4.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f,
        28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f,
        64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f,
        88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f,
        112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f,
        136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f,
        172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f,
        196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f});
    input.linspace(1.);

    sd::ops::maxpool3dnew op;
    auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) {
    // avgpool3dnew_bp: NCDHW input, VALID mode, no explicit padding, constant upstream gradient of 2.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
        0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
        0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
        0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
        0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
        0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
        0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
        0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
        0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f});
    input.linspace(1.);
    gradO = 2.;

    sd::ops::avgpool3dnew_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) {
    // avgpool3dnew_bp: NCDHW input, VALID mode with explicit padding of 1 on every axis,
    // constant upstream gradient of 2.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat  = 0;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
        1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f});
    input.linspace(1.);
    gradO = 2.;

    sd::ops::avgpool3dnew_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) {
    // avgpool3dnew_bp: NDHWC input, SAME mode, extra pooling parameter (index 13) set to 0,
    // constant upstream gradient of 2.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat  = 1;   // 1-NDHWC, 0-NCDHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
        0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f,
        1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
        1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
        0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
        0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f,
        1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
        1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f});
    input.linspace(1.);
    gradO = 2.;

    sd::ops::avgpool3dnew_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});

    // Check status before reading the output.
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) {
    // avgpool3dnew_bp on an NDHWC input with a constant upstream gradient (2) and extraParam0 = 0;
    // the computed input gradient is compared element-wise with a precomputed reference.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
        0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f,
        1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f,
        1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f,
        0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
        0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f,
        1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f,
        1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f});

    input.linspace(1.);
    gradO = 2.;

    sd::ops::avgpool3dnew_bp op;
    // integer args: kernel, stride, padding, dilation (D,H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {
    // maxpool3dnew_bp on an NCDHW input, VALID padding, no explicit pads;
    // upstream gradient is linspace(0.1, 0.1) and the result is compared with a precomputed reference.
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=2,oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool3dnew_bp op;
    // integer args: kernel, stride, padding, dilation (D,H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) {
    // maxpool3dnew_bp on an NCDHW input, VALID padding mode with explicit pads of 1 in every
    // spatial dimension; upstream gradient is linspace(0.1, 0.1).
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1;
    int oD=4,oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f,
        0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool3dnew_bp op;
    // integer args: kernel, stride, padding, dilation (D,H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) {
    // maxpool3dnew_bp on an NDHWC input with SAME padding; upstream gradient is linspace(0.1, 0.1).
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f,
        0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f,
        0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool3dnew_bp op;
    // integer args: kernel, stride, padding, dilation (D,H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) {
    // maxpool3dnew_bp on an NDHWC input with padding mode 0 and no explicit pads, fed a gradient
    // of spatial size oD x oH x oW; upstream gradient is linspace(0.1, 0.1).
    int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1;
    int oD=3,oH=4,oW=3;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NDHWC, 0-NCDHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f,
        14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f,
        24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool3dnew_bp op;
    // integer args: kernel, stride, padding, dilation (D,H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_bp_1) {
    // Runs maxpool2d_bp through the Context/VariableSpace path using the fixture's dimensions;
    // only the shape of the produced gradient is checked, not its values.
    auto input = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto epsilon = NDArrayFactory::create_<float>('c', {bS,iD,oH,oW});
    auto exp = NDArrayFactory::create<float>('c', {bS,iD,iH,iW});

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, epsilon);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1});
    block->fillInputs({-2});

    std::vector<int>* argI = block->getIArguments();
    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width;
    // 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0; 10 - data format
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0, 0, 0};

    sd::ops::maxpool2d_bp bp;
    const Nd4jStatus status = bp.execute(block);

    // Evaluate the shape check while the objects are alive, release them, and only then assert:
    // a failing ASSERT_* returns from the test body and would otherwise leak both allocations.
    const bool sameShape = status == ND4J_STATUS_OK &&
                           exp.isSameShape(variableSpace->getVariable(block->getNodeId())->getNDArray());

    // the context was constructed with a pointer to the variable space, so release it first
    delete block;
    delete variableSpace;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_bp_2) {
    // maxpool2d_bp on a 4x4 NCHW input with a 2x2 kernel and unit stride, comparing the input
    // gradient with a hand-written reference.
    int bS=2, iD=1, iH=4,iW=4, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1;
    int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1;

    NDArray input('c', {bS,iD,iH,iW});
    NDArray epsilon('c', {bS,iD,oH,oW}, {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.});
    NDArray expected('c', {bS,iD,iH,iW}, {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.});

    input.linspace(1.);

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width;
    // 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0; 10 - data format
    std::initializer_list<Nd4jLong> argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0, 0, 0};

    sd::ops::maxpool2d_bp op;
    auto results = op.evaluate({&input, &epsilon}, {}, argI);

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) {
    // maxpool2d_bp on an NCHW input, VALID padding, no explicit pads;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
        0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) {
    // maxpool2d_bp on an NCHW input, VALID padding mode with explicit pads of 1;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1;
    int oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f,
        0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f,
        0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) {
    // maxpool2d_bp on an NHWC input with SAME padding; upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f,
        0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f,
        0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) {
    // maxpool2d_bp on an NHWC input, VALID padding, no explicit pads;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f,
        0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
        0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, maxpool2d_bp_7) {
    // float16 maxpool2d_bp on a 56x56 NCHW input with a 2x2 kernel, stride 2 and SAME padding.
    // No reference values exist for this size, so the test checks that the op succeeds and that
    // the produced gradient has the input's shape.
    int bS=2, iH=56,iW=56, iC=3, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
    int oH=28,oW=28;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<float16>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<float16>('c', {bS, iC, oH, oW});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::maxpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    ASSERT_EQ(Status::OK(), results.status());

    // the gradient with respect to the input must have the input's shape
    auto output = results.at(0);
    ASSERT_TRUE(input.isSameShape(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, avgpool2d_bp_1) {
    // Runs avgpool2d_bp through the Context/VariableSpace path using the fixture's dimensions;
    // only the shape of the produced gradient is checked, not its values.
    auto input = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto epsilon = NDArrayFactory::create_<float>('c', {bS,iD,oH,oW});
    auto exp = NDArrayFactory::create<float>('c', {bS,iD,iH,iW});

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, epsilon);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1});
    block->fillInputs({-2});

    std::vector<int>* argI = block->getIArguments();
    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width;
    // 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0; 10 - data format
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0, 1, 0};

    sd::ops::avgpool2d_bp bp;
    const Nd4jStatus status = bp.execute(block);

    // Evaluate the shape check while the objects are alive, release them, and only then assert:
    // a failing ASSERT_* returns from the test body and would otherwise leak both allocations.
    const bool sameShape = status == ND4J_STATUS_OK &&
                           exp.isSameShape(variableSpace->getVariable(block->getNodeId())->getNDArray());

    // the context was constructed with a pointer to the variable space, so release it first
    delete block;
    delete variableSpace;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) {
    // avgpool2d_bp on a 4x4 NCHW input with a 2x2 kernel and unit stride, comparing the input
    // gradient with a hand-written reference.
    int bS=2, iD=1, iH=4,iW=4, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1;
    int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1;

    auto input = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW});
    auto epsilon = NDArrayFactory::create<TypeParam>('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f});

    input.linspace(1.);

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width;
    // 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0; 10 - data format
    std::initializer_list<Nd4jLong> argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1, 1, 0};

    sd::ops::avgpool2d_bp op;
    auto results = op.evaluate({&input, &epsilon}, {}, argI);

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) {
    // avgpool2d_bp on an NCHW input, VALID padding, no explicit pads, extraParam0 = 1;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f,
        0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f,
        0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f,
        0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f,
        0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f,
        0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f });

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::avgpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) {
    // avgpool2d_bp on an NCHW input, VALID padding mode with explicit pads of 1, extraParam0 = 1;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1;
    int oH=4,oW=4;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f,
        1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f,
        2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f,
        3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f,
        4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f,
        5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::avgpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) {
    // avgpool2d_bp on an NHWC input with SAME padding and extraParam0 = 0;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=3;
    int paddingMode = 1;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f,
        1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f,
        1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f,
        3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::avgpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) {
    // avgpool2d_bp on an NHWC input, VALID padding, no explicit pads, extraParam0 = 1;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 1;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f,
        0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f,
        0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f,
        0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::avgpool2d_bp op;
    // integer args: kernel, stride, padding, dilation (H,W each), padding mode, extraParam0, data format
    auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, pnormpool2d_bp_1) {
    // Runs pnormpool2d_bp through the Context/VariableSpace path using the fixture's dimensions;
    // only the shape of the produced gradient is checked, not its values.
    auto input = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
    auto epsilon = NDArrayFactory::create_<float>('c', {bS,iD,oH,oW});
    auto exp = NDArrayFactory::create<float>('c', {bS,iD,iH,iW});

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, epsilon);

    auto block = new Context(1, variableSpace, false);
    block->fillInputs({-1});
    block->fillInputs({-2});

    auto argI = block->getIArguments();
    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width;
    // 6,7 - dilation Height/Width; 8 - same mode; 9 - p-norm order
    *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0, 3};

    // floating-point arg: epsilon
    std::vector<double>* argT = block->getTArguments();
    *argT = {0.000001};

    sd::ops::pnormpool2d_bp bp;
    const Nd4jStatus status = bp.execute(block);

    // Evaluate the shape check while the objects are alive, release them, and only then assert:
    // a failing ASSERT_* returns from the test body and would otherwise leak both allocations.
    const bool sameShape = status == ND4J_STATUS_OK &&
                           exp.isSameShape(variableSpace->getVariable(block->getNodeId())->getNDArray());

    // the context was constructed with a pointer to the variable space, so release it first
    delete block;
    delete variableSpace;

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(sameShape);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) {
    // pnormpool2d_bp (p = 3, eps = 0) on an NCHW input, VALID padding, no explicit pads;
    // upstream gradient is linspace(0.1, 0.1).
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int pnorm = 3;
    double eps = 0.;

    int paddingMode = 0;   // 1-SAME, 0-VALID
    int dataFormat = 0;    // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f,
        8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f,
        2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f,
        3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f,
        4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f,
        5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f});

    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::pnormpool2d_bp op;
    // T arg: eps; integer args: kernel, stride, padding, dilation (H,W each), padding mode, p-norm order, data format
    auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});

    // check the op status before indexing into the result set
    ASSERT_EQ(Status::OK(), results.status());

    auto output = results.at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) {
    // Backprop of 2-norm pooling (VALID, NCHW) against precomputed gradients.
    int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int pnorm = 2;
    double eps = 0.;

    int paddingMode = 0;            // 1-SAME, 0-VALID
    int dataFormat  = 0;            // 1-NHWC, 0-NCHW

    auto input    = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
    auto gradO    = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
    auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f,
                                                                        0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f,
                                                                        0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f,
                                                                        0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f,
                                                                        0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f,
                                                                        0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f});
    input.linspace(1.);
    gradO.linspace(0.1, 0.1);

    sd::ops::pnormpool2d_bp op;
    auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, upsampling2d_bp_1) {
    // NCHW: every input pixel is replicated factorH*factorW = 4 times, so a
    // gradient of all ones sums to 4 per input element.
    const int bS=1, iH=2,iW=2, iC=1;
    const int factorH=2, factorW=2;
    const int isNCHW = 1;                    // data format, default is NCHW

    auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW});
    gradO = 1.;

    auto expGradI = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    expGradI = 4.;

    sd::ops::upsampling2d_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* gradI = results.at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, upsampling2d_bp_2) {
    // Same as upsampling2d_bp_1 but with channels-last (NHWC) layout.
    const int bS=1, iH=2,iW=2, iC=1;
    const int factorH=2, factorW=2;
    const int isNCHW = 0;                    // 0 - NHWC (default data format is NCHW)

    auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    auto gradO = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC});
    gradO = 1.;

    auto expGradI = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    expGradI = 4.;

    sd::ops::upsampling2d_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* gradI = results.at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, upsampling2d_bp_3) {
    // NCHW upsampling backprop with non-uniform gradient values.
    const int bS=1, iH=3,iW=3, iC=2;
    const int factorH=2, factorW=2;
    const int isNCHW = 1;                    // data format, default is NCHW

    NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);

    NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984,
                    0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761,
                    0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287,
                    0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108,
                    0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972,
                    0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549,
                    0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, sd::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644,
                    1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, sd::DataType::FLOAT32);

    sd::ops::upsampling2d_bp op;
    auto results = op.evaluate({&input, &gradO}, {}, {isNCHW});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* gradI = results.at(0);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) {
    // Depthwise conv, SAME padding, NHWC, weights laid out as {kH, kW, iC, mC}.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=4,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});

    auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
                                                                              13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,
                                                                              12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
                                                                              13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f});
    input = 2.;
    weights.linspace(0.1, 0.1);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_2) {
    // Depthwise conv, VALID padding, NHWC; constant input gives identical outputs per channel.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});

    auto expOutput = NDArrayFactory::create<float>('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f,
                                                                          13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f});
    input = 2.;
    weights.linspace(0.1, 0.1);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_3) {
    // Depthwise conv, VALID padding, NCHW, with biases and a permuted (non-contiguous) weights view.
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=2,oW=2;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<float>('c', {mC, iC, kH, kW});
    auto biases  = NDArrayFactory::create<float>('c', {iC*mC}, {1.f,2.f,3.f,4.f});

    NDArray expOutput('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}, sd::DataType::FLOAT32);

    input = 2.;
    weights.linspace(0.1, 0.1);
    // {mC, iC, kH, kW} -> {kH, kW, iC, mC}, the weights layout used by the other tests
    weights.permutei({2,3,1,0});

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_4) {
    // Large NHWC input with stride 2 and SAME padding. The preallocated output is
    // filled with a sentinel value before the op runs; afterwards the trailing part
    // of the output must no longer contain that sentinel.
    int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=56,oW=56;

    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    const float unique = -1000000;

    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);
    NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
    input.linspace(0.1, 0.0001);
    weights = 0.5;
    output = unique;

    sd::ops::depthwise_conv2d op;
    Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {});

    ASSERT_EQ(Status::OK(), status);

    // Only the last third of the output is inspected; the start index is the
    // truncated value of lengthOf()/1.5, made explicit instead of an implicit
    // double -> integer conversion in the loop header.
    const Nd4jLong start = static_cast<Nd4jLong>(output.lengthOf() / 1.5);
    for (Nd4jLong i = start; i < output.lengthOf(); ++i)
        ASSERT_NE(unique, output.e<float>(i)) << "output element " << i << " was not written";
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_5) {
    // Depthwise conv, SAME padding, NHWC, mC=1, constant 0.5 weights.
    int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});

    NDArray expOutput('c', {bS, oH, oW, oC}, {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., 14., 15., 16., 17., 8.5, 9.}, sd::DataType::FLOAT32);

    input.linspace(1.);
    weights = 0.5;

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_6) {
    // Depthwise conv, SAME padding, NHWC, mC=1, all-ones weights.
    int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}, sd::DataType::FLOAT32);
    input.linspace(1.);
    weights = 1.;

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    NDArray* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_7) {
    // 1x1 depthwise conv, VALID padding, NCHW, with a {1, iC*mC} bias row.
    int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 0;             // 1-SAME, 0-VALID;
    int dataFormat  = 0;             // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579,
                                          0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804,
                                          0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}, sd::DataType::FLOAT32);
    NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978,
                                              0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472,
                                              0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504,
                                              0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385,
                                              0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_8) {
    // 3x3 depthwise conv, SAME padding, NHWC, 8 channels, linspace input and weights.
    int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=10,oW=10;
    int paddingMode = 1;             // 1-SAME, 0-VALID;
    int dataFormat  = 1;             // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000,
    -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998,
    25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998,
    65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993,
    139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987,
    192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990,
    214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988,
    288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989,
    384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027,
    610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024,
    776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011,
    31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32);

    input.linspace(-10, 0.1);
    weights.linspace(-2, 0.1);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // check the status before touching the result set
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_9) {
|
|
|
|
int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
|
|
int oC=iC*mC;
|
|
int oH=10,oW=10;
|
|
int paddingMode = 1; // 1-SAME, 0-VALID;
|
|
int dataFormat = 0; // 1-NHWC, 0-NCHW
|
|
|
|
NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
|
|
NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);
|
|
|
|
NDArray expOutput('c', {bS, oC, oH, oW}, {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, -125.680000, -124.240005, -122.799995, -121.360001, -66.720001,-76.199997, -81.239998, -80.160004, -79.080002, -78.000000, -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, -66.599998, -70.440002, -69.360001, -68.279999,
|
|
-67.199997, -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, -37.799999, -38.040001,
|
|
-36.959999, -35.879997, -34.799999, -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, -22.919998,-21.840000, -20.759998, -19.679998, -5.400000, -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, -12.120001, -11.040000, -9.960001, -8.880000, -0.600000,
|
|
-9.000000, -5.639999, -4.560000, -3.480000, -2.400000, -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, 3.920000,3.920001, 3.920000, 3.920000, 3.520000, 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, 20.430000, 27.750000,
|
|
28.919998, 30.090000, 31.260000, 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, 51.029999, 62.849998,
|
|
64.019997, 65.190002, 66.360001, 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, 81.630005, 97.949997,
|
|
99.120003, 100.290009, 101.459999, 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, 129.080002, 169.279999,
|
|
170.839996, 172.399994, 173.960007, 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, 172.380005, 173.639999, 174.899994, 176.160004, 86.820000,
|
|
150.660004, 179.940002, 181.200012, 182.459991, 183.720001, 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, 210.179993, 211.440002, 212.700012,
|
|
213.959991, 104.819992, 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, 193.860016, 230.339996, 231.600006, 232.860001, 234.119995, 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, 247.980011,
|
|
249.239990, 250.500000, 251.759995, 122.819992, 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, 418.320007, 505.919983, 507.359985, 508.800018, 510.240051, 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, 430.319977, 520.320007, 521.760010, 523.200012, 
524.640015, 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, 921.919983, 923.719971, 925.520020, 927.320007, 536.519958, 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, 696.779968, 
858.900024, 860.520020, 862.140015, 863.760010, 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, 1152.030029, 1153.739990, 1155.449951, 1157.160034, 611.370056, 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, 948.809998, 1179.390015, 
1181.099976, 1182.810059, 1184.520020, 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, sd::DataType::FLOAT32);
|
|
|
|
input.linspace(-10, 0.1);
|
|
weights.linspace(-2, 0.1);
|
|
|
|
sd::ops::depthwise_conv2d op;
|
|
auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
|
|
auto output = results.at(0);
|
|
|
|
ASSERT_EQ(Status::OK(), results.status());
|
|
|
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
|
ASSERT_TRUE(expOutput.equalsTo(output, 1e-4));
|
|
|
|
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_10) {
    // Forward depthwise_conv2d: NCHW input, 1x1 kernel, VALID padding, weights in
    // wFormat=1 layout [mC, iC, kH, kW], with a bias of length iC*mC.
    int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=3,oW=3;
    int paddingMode = 0; // 1-SAME, 0-VALID;
    int dataFormat  = 0; // 1-NHWC, 0-NCHW
    int wFormat     = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC]

    NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579,
                                          0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804,
                                          0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32);
    NDArray weights('c', {mC, iC, kH, kW}, {0.130845, 0.569885, 0.644284, 0.198968}, sd::DataType::FLOAT32);
    NDArray biases('c', {iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978,
                                              0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472,
                                              0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504,
                                              0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385,
                                              0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat});

    // Check the op status before touching any outputs, so a failing op is reported
    // as a status mismatch rather than failing while reading the result set.
    ASSERT_EQ(Status::OK(), results.status());
    auto* output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_11) {
    // Forward depthwise_conv2d: NHWC input, 3x3 kernel, SAME padding, channel
    // multiplier mC=1, weights in wFormat=2 layout [mC, kH, kW, iC], no bias.
    int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=10,oW=10;
    int paddingMode = 1; // 1-SAME, 0-VALID;
    int dataFormat  = 1; // 1-NHWC, 0-NCHW
    int wFormat     = 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC]

    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray weights('c', {mC, kH, kW, iC}, {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1,
                                            0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.,
                                            3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, sd::DataType::FLOAT32);

    NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000,
        -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998,
        25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998,
        65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993,
        139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987,
        192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990,
        214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988,
        288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988,
        290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989,
        384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 
        610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 
        776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 
        31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32);

    input.linspace(-10, 0.1);
    // Overwrites the literal above with linspace(-2, 0.1), i.e. the same
    // -2.0 .. 5.1 sequence the literal spells out (72 = mC*kH*kW*iC values).
    weights.linspace(-2, 0.1);

    sd::ops::depthwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat});

    // Check the op status before touching any outputs, so a failing op is reported
    // as a status mismatch rather than failing while reading the result set.
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);

    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) {
    // Backprop of depthwise_conv2d: NHWC input, 3x2 kernel, SAME padding, default
    // weight layout [kH, kW, iC, mC]. Checks gradI and gradW (gradB is not checked).
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=4,oW=3;
    int oC=iC*mC;
    int paddingMode = 1; // 1-SAME, 0-VALID;
    int dataFormat  = 1; // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
    auto bias    = NDArrayFactory::create<float>('c', {oC}, {1,2,3,4});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});

    NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308,
                                            1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. ,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32);

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // Check the op status before touching any outputs, so a failing op is reported
    // as a status mismatch rather than failing while reading the result set.
    ASSERT_EQ(Status::OK(), results.status());
    auto* gradI = results.at(0);
    auto* gradW = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) {
    // Backprop of depthwise_conv2d: NHWC input, 3x2 kernel, VALID padding
    // (oH = 4-3+1 = 2, oW = 3-2+1 = 2), default weight layout [kH, kW, iC, mC].
    // Checks gradI and gradW (gradB is not checked).
    int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int oC=iC*mC;
    int paddingMode = 0; // 1-SAME, 0-VALID;
    int dataFormat  = 1; // 1-NHWC, 0-NCHW

    auto input   = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
    auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, mC});
    auto bias    = NDArrayFactory::create<float>('c', {oC}, {1,2,3,4});
    auto gradO   = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});

    NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729,
                                            0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32);
    NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32);

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // Check the op status before touching any outputs, so a failing op is reported
    // as a status mismatch rather than failing while reading the result set.
    ASSERT_EQ(Status::OK(), results.status());
    auto* gradI = results.at(0);
    auto* gradW = results.at(1);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test3) {
    // Smoke test: runs depthwise_conv2d_bp into caller-provided output arrays on a
    // larger NCHW problem and only checks that the op reports success; the
    // gradient values themselves are not verified.
    const int kH = 2, kW = 2, sH = 1, sW = 1, pH = 0, pW = 0, dH = 1, dW = 1;
    const int paddingMode = 1; // 1-SAME, 0-VALID
    const int dataFormat  = 0; // 1-NHWC, 0-NCHW

    auto in   = NDArrayFactory::create<float>('c', {4, 8, 64, 64});   // [bS, iC, iH, iW]
    auto w    = NDArrayFactory::create<float>('c', {2, 2, 8, 2});     // [kH, kW, iC, mC]
    auto b    = NDArrayFactory::create<float>('c', {1, 16});          // oC = iC*mC = 16 values
    auto grad = NDArrayFactory::create<float>('c', {4, 16, 64, 64});  // [bS, oC, oH, oW]; SAME padding with stride 1 keeps 64x64

    auto gradI = in.like();
    auto gradW = w.like();
    auto gradB = b.like();

    // The previous spelling `nd4j:ops::depthwise_conv2d_bp` only compiled because
    // `nd4j:` parsed as an (unused) statement label; name the namespace explicitly.
    sd::ops::depthwise_conv2d_bp op;
    auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB},
                             {kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, dataFormat});
    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) {
    // Backprop of depthwise_conv2d: NHWC input, 3x3 kernel, SAME padding, mC=1,
    // default weight layout [kH, kW, iC, mC]. Checks gradI, gradW and gradB.
    int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=10,oW=10;
    int paddingMode = 1; // 1-SAME, 0-VALID;
    int dataFormat  = 1; // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);
    NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
    // Bias values are never set here; a bias-add's gradients do not depend on them.
    NDArray bias('c', {oC}, sd::DataType::FLOAT32);

    input.linspace(-10, 0.1);
    weights.linspace(-2, 0.1);
    gradO.linspace(10, -0.1);

    NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008,
        93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, -30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997,
        -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, -195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984,
        -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, -412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001,
        -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, -657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002,
        -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875,
        -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062,
        -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500,
        -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000,
        -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32);

    // gradB[c] is the sum of gradO over all 100 spatial positions of channel c.
    NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // Check the op status before touching any outputs, so a failing op is reported
    // as a status mismatch rather than failing while reading the result set.
    ASSERT_EQ(Status::OK(), results.status());
    NDArray* gradI = results.at(0);
    NDArray* gradW = results.at(1);
    NDArray* gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
// depthwise_conv2d_bp: SAME padding, NCHW data, weights in the default
// [kH, kW, iC, mC] layout, unit stride/dilation. Checks gradI, gradW and gradB.
// gradB is the per-channel sum of gradO over batch and spatial dims
// (505 - 1000*c for the linspace used below).
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) {

    int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oC=iC*mC;
    int oH=10,oW=10;        // SAME padding with unit stride keeps the spatial size
    int paddingMode = 1;    // 1-SAME, 0-VALID;
    int dataFormat = 0;     // 1-NHWC, 0-NCHW

    NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
    NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32);
    NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
    NDArray bias('c', {oC}, sd::DataType::FLOAT32);

    input.linspace(-10, 0.1);
    weights.linspace(-2, 0.1);
    gradO.linspace(10, -0.1);

    NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, -44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000,
        -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001,
        -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, -129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912,
        -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, -509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012,
        -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, -822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980,
        -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, -1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000,
        -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951,
        -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188,
        -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625,
        -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500,
        -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32);

    // per-channel sum of gradO: channel c covers linspace entries [100c, 100c+99]
    NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before reading its outputs, so a failure is reported
    // as a status mismatch (the outputs are presumably not populated on failure).
    ASSERT_EQ(Status::OK(), results.status());

    NDArray* gradI = results.at(0);
    NDArray* gradW = results.at(1);
    NDArray* gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
// depthwise_conv2d_bp: VALID padding, NCHW data, double precision, default
// [kH, kW, iC, mC] weights layout, with a bias supplied.
// Because every input element equals 2, gradW[.., c, ..] = 2 * sum(gradO[:, c, :, :]),
// and gradB[c] = sum(gradO[:, c, :, :]). For the linspace below those sums are
// 0.52 (channel 0) and 0.84 (channel 1).
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) {

    int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int oC=iC*mC;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat = 0;     // 1-NHWC, 0-NCHW

    auto input = NDArrayFactory::create<double>('c', {bS, iC, iH, iW});
    auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
    auto bias = NDArrayFactory::create<double>('c', {oC}, {3,4});
    auto gradO = NDArrayFactory::create<double>('c', {bS, oC, oH, oW});

    auto expGradI = NDArrayFactory::create<double>('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01,
        0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136,
        0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192});

    auto expGradW = NDArrayFactory::create<double>('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68});

    // channel 0: (0.01+..+0.04) + (0.09+..+0.12) = 0.52; channel 1: (0.05+..+0.08) + (0.13+..+0.16) = 0.84
    auto expGradB = NDArrayFactory::create<double>('c', {oC}, {0.52, 0.84});

    input = 2.;
    weights.linspace(0.1, 0.1);
    gradO.linspace(0.01, 0.01);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    // Verify the op succeeded before reading its outputs, so a failure is reported
    // as a status mismatch (the outputs are presumably not populated on failure).
    ASSERT_EQ(Status::OK(), results.status());

    auto* gradI = results.at(0);
    auto* gradW = results.at(1);
    auto* gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
//////////////////////////////////////////////////////////////////////
|
|
// Same scenario as depthwise_conv2d_bp_test6, but in FLOAT32 and with weights
// laid out as [mC, iC, kH, kW] (wFormat = 1). With constant input 2 and VALID
// padding, gradW[.., c, ..] = 2 * sum(gradO[:, c, :, :]) and
// gradB[c] = sum(gradO[:, c, :, :]) = {0.52, 0.84}.
TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) {

    int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=2,oW=2;
    int oC=iC*mC;
    int paddingMode = 0;    // 1-SAME, 0-VALID;
    int dataFormat = 0;     // 1-NHWC, 0-NCHW
    int wFormat = 1;        // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC]

    NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
    NDArray weights('c', {mC, iC, kH, kW}, {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, sd::DataType::FLOAT32);
    NDArray bias('c', {oC}, {3,4}, sd::DataType::FLOAT32);
    NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);

    NDArray expGradI('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01,
        0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136,
        0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, sd::DataType::FLOAT32);

    NDArray expGradW('c', {mC, iC, kH, kW}, {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, sd::DataType::FLOAT32);

    // channel 0: (0.01+..+0.04) + (0.09+..+0.12) = 0.52; channel 1: (0.05+..+0.08) + (0.13+..+0.16) = 0.84
    NDArray expGradB('c', {oC}, {0.52, 0.84}, sd::DataType::FLOAT32);

    input = 2.;
    gradO.linspace(0.01, 0.01);

    sd::ops::depthwise_conv2d_bp op;
    auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat});

    // Verify the op succeeded before reading its outputs, so a failure is reported
    // as a status mismatch (the outputs are presumably not populated on failure).
    ASSERT_EQ(Status::OK(), results.status());

    auto* gradI = results.at(0);
    auto* gradW = results.at(1);
    auto* gradB = results.at(2);

    ASSERT_TRUE(expGradI.isSameShape(gradI));
    ASSERT_TRUE(expGradI.equalsTo(gradI));

    ASSERT_TRUE(expGradW.isSameShape(gradW));
    ASSERT_TRUE(expGradW.equalsTo(gradW));

    ASSERT_TRUE(expGradB.isSameShape(gradB));
    ASSERT_TRUE(expGradB.equalsTo(gradB));
}
|
|
|
|
#endif //LIBND4J_CONVOLUTIONTESTS2_H
|