[WIP] MSVC-related tests fixes (#88)

* fix narrowing down cast

Signed-off-by: raver119 <raver119@gmail.com>

* trigger jenkins

Signed-off-by: raver119 <raver119@gmail.com>

* few more fixes for MSVC and Windows

Signed-off-by: raver119 <raver119@gmail.com>

* few more fixes for MSVC and Windows

Signed-off-by: raver119 <raver119@gmail.com>

* few more fixes for MSVC and Windows

Signed-off-by: raver119 <raver119@gmail.com>

* few more fixes for MSVC and Windows

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks
- tensormmul dtype validation

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks
- batched gemm dtype validation

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-30 16:02:07 +03:00 committed by GitHub
parent 2be47082c9
commit 4ada65b384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1209 additions and 1215 deletions

View File

@ -110,19 +110,7 @@ DECLARE_SHAPE_FN(batched_gemm) {
auto shapeList = SHAPELIST(); auto shapeList = SHAPELIST();
if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) { if (!(M > 0 && N > 0 && K > 0 && ldA > 0 && ldB > 0 && ldC > 0 && batchSize > 0)) {
Nd4jLong *newShape; shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'c', {1, 1}));
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(2), Nd4jLong);
newShape[0] = 2;
newShape[1] = 1;
newShape[2] = 1;
newShape[3] = 1;
newShape[4] = 1;
newShape[5] = 0;
newShape[6] = 1;
newShape[7] = 99;
shapeList->push_back(newShape);
return shapeList; return shapeList;
} }
@ -130,7 +118,7 @@ DECLARE_SHAPE_FN(batched_gemm) {
std::vector<Nd4jLong> shape({M, N}); std::vector<Nd4jLong> shape({M, N});
for (int e = 0; e < batchSize; e++) { for (int e = 0; e < batchSize; e++) {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'f', shape); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(0)), 'f', shape);
shapeList->push_back(newShape); shapeList->push_back(newShape);
} }

View File

@ -31,7 +31,9 @@ namespace nd4j {
auto a = INPUT_VARIABLE(0); auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1); auto b = INPUT_VARIABLE(1);
auto c = OUTPUT_VARIABLE(0); // auto c = OUTPUT_VARIABLE(0); //
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
// building axes // building axes
int axe0_size = INT_ARG(0); int axe0_size = INT_ARG(0);
@ -54,7 +56,10 @@ namespace nd4j {
DECLARE_SHAPE_FN(tensormmul) { DECLARE_SHAPE_FN(tensormmul) {
auto aShapeInfo = inputShape->at(0); auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1); auto bShapeInfo = inputShape->at(1);
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
// building axes // building axes
int axe0_size = INT_ARG(0); int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1); int axe1_size = INT_ARG(axe0_size+1);
@ -70,7 +75,7 @@ namespace nd4j {
std::vector<Nd4jLong> shapeAt, shapeBt; std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt); auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(block.dataType(), 'c', outShape))); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
} }
DECLARE_TYPES(tensormmul) { DECLARE_TYPES(tensormmul) {

View File

@ -45,7 +45,10 @@ endif()
if (APPLE) if (APPLE)
set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true") set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2 -D__APPLE_OS__=true")
elseif(WIN32) elseif(WIN32)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -march=native -mtune=native -O3") if (CPU_BLAS)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC -march=native -mtune=native -O3")
endif()
if (CPU_BLAS AND LINUX) if (CPU_BLAS AND LINUX)
set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2") set(CMAKE_CXX_FLAGS " -fPIC -std=c++11 -fmax-errors=2")
endif() endif()

View File

@ -134,7 +134,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_1) {
TYPED_TEST(TypedConvolutionTests1, conv2d_2) { TYPED_TEST(TypedConvolutionTests1, conv2d_2) {
auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4}); auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
auto weights = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4}); auto weights = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 4});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); auto exp = NDArrayFactory::create<TypeParam>('c', {1, 4, 1, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f});
weights.assign(2.0); weights.assign(2.0);
input.linspace(1); input.linspace(1);
@ -161,7 +161,7 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1.f, 2.f, 3.f});
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f,
@ -762,10 +762,10 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}); auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12});
auto bias = NDArrayFactory::create<TypeParam>('c', {3}); auto bias = NDArrayFactory::create<TypeParam>('c', {3});
auto expFF = NDArrayFactory::create<TypeParam>('c', {2, 3, 5}, {59.0, 69.0, 79.0, 89.0, 99.0, 132.0, 158.0, 184.0, 210.0, 236.0, 205.0, 247.0, 289.0, 331.0, 373.0, 179.0, 189.0, 199.0, 209.0, 219.0, 444.0, 470.0, 496.0, 522.0, 548.0, 709.0, 751.0, 793.0, 835.0, 877.0}); auto expFF = NDArrayFactory::create<TypeParam>('c', {2, 3, 5}, {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f});
auto expEps = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}, {130.0, 293.0, 326.0, 359.0, 392.0, 220.0, 166.0, 371.0, 416.0, 461.0, 506.0, 280.0, 355.0, 788.0, 821.0, 854.0, 887.0, 490.0, 481.0, 1046.0, 1091.0, 1136.0, 1181.0, 640.0}); auto expEps = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}, {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f});
auto expGW = NDArrayFactory::create<TypeParam>('c', {3, 2, 2}, {1415.0, 1520.0, 2045.0, 2150.0, 1865.0, 2020.0, 2795.0, 2950.0, 2315.0, 2520.0, 3545.0, 3750.0}); auto expGW = NDArrayFactory::create<TypeParam>('c', {3, 2, 2}, {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, 2315.0f, 2520.0f, 3545.0f, 3750.0f});
auto expGB = NDArrayFactory::create<TypeParam>('c', {3}, {105.0, 155.0, 205.0}); auto expGB = NDArrayFactory::create<TypeParam>('c', {3}, {105.0f, 155.0f, 205.0f});
expGW.permutei({2,1,0}); expGW.permutei({2,1,0});
input.linspace(1); input.linspace(1);
@ -809,7 +809,7 @@ TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) {
TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) {
auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6}); auto input = NDArrayFactory::create<TypeParam>('c', {2, 2, 6});
auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1,5,9,3,7,11,2,6,10,4,8,12}); auto weights = NDArrayFactory::create<TypeParam>('c', {2, 2, 3}, {1.f, 5.f, 9.f, 3.f, 7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f});
input.linspace(1); input.linspace(1);
@ -1164,7 +1164,6 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=2,oW=2; int oH=2,oW=2;
int paddingMode = 0; // 1-SAME, 0-VALID; int paddingMode = 0; // 1-SAME, 0-VALID;
@ -1175,16 +1174,16 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567, 1.224,0.66 ,1.314, 2.82 ,1.512,1.386, 2.976,1.596,0.801, 1.71 ,0.912,0.657, 1.422,0.768,1.53 , 3.288,1.764,1.602, 3.444,1.848,0.927, 1.98 ,1.056, auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f,
0.747, 1.62 ,0.876,1.746, 3.756,2.016,1.818, 3.912,2.1 ,1.053, 2.25 ,1.2 ,0.837, 1.818,0.984,1.962, 4.224,2.268,2.034, 4.38 ,2.352,1.179, 2.52 ,1.344, 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f,
1.467, 3.06 ,1.596,3.186, 6.636,3.456,3.402, 7.08 ,3.684,1.845, 3.834,1.992,1.773, 3.69 ,1.92 ,3.834, 7.968,4.14 ,4.05 , 8.412,4.368,2.187, 4.536,2.352, 1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f,
2.079, 4.32 ,2.244,4.482, 9.3 ,4.824,4.698, 9.744,5.052,2.529, 5.238,2.712,2.385, 4.95 ,2.568,5.13 ,10.632,5.508,5.346,11.076,5.736,2.871, 5.94 ,3.072}); 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00, auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f,
1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,1.3600e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f,
2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.0000e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f,
2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00,2.6400e+00}); 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f});
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68, 1., 1.32}); auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68f, 1.f, 1.32f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1253,21 +1252,21 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226, 0.343, 0.46 , 0.577, 1.172, 1.46 , 1.748, 2.036, 1.892, 2.288, 2.684, 3.08 , 1.284, 1.581, 1.878, 2.175, 4.458, 5.133, 5.808, 6.483, 6.186, 7.023, 7.86 , 8.697, 3.39 , 3.93 , 4.47 , 5.01 , 9.642, 10.803, 11.964, 13.125, 11.37 , 12.693, 14.016, 15.339, auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f,
5.266, 5.707, 6.148, 6.589, 12.98 , 13.916, 14.852, 15.788, 14.564, 15.608, 16.652, 17.696, 6.284, 7.166, 8.048, 8.93 , 17.896, 19.768, 21.64 , 23.512, 21.928, 24.016, 26.104, 28.192, 18.12 , 19.686, 21.252, 22.818, 45.852, 49.146, 52.44 , 55.734, 53.196, 56.814, 60.432, 64.05 , 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f,
28.164, 30.216, 32.268, 34.32 , 67.884, 72.15 , 76.416, 80.682, 75.228, 79.818, 84.408, 88.998, 29.324, 30.854, 32.384, 33.914, 67.432, 70.6 , 73.768, 76.936, 73.192, 76.576, 79.96 , 83.344, 27.884, 30.062, 32.24 , 34.418, 66.28 , 70.744, 75.208, 79.672, 70.312, 74.992, 79.672, 84.352, 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f,
58.296, 61.806, 65.316, 68.826,133.98 , 141.162, 148.344, 155.526,141.324, 148.83 , 156.336, 163.842, 68.34 , 72.336, 76.332, 80.328,156.012, 164.166, 172.32 , 180.474,163.356, 171.834, 180.312, 188.79 , 61.292, 64.118, 66.944, 69.77 ,136.552, 142.312, 148.072, 153.832,142.312, 148.288, 154.264, 160.24 , 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f,
9.298, 11.359, 13.42 , 15.481, 27.092, 31.268, 35.444, 39.62 , 27.812, 32.096, 36.38 , 40.664, 26.556, 29.769, 32.982, 36.195, 66.666, 73.173, 79.68 , 86.187, 68.394, 75.063, 81.732, 88.401, 28.662, 32.118, 35.574, 39.03 , 71.85 , 78.843, 85.836, 92.829, 73.578, 80.733, 87.888, 95.043, 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f,
29.89 , 32.275, 34.66 , 37.045, 70.004, 74.828, 79.652, 84.476, 71.588, 76.52 , 81.452, 86.384, 71.084, 75.854, 80.624, 85.394,163.048, 172.696, 182.344, 191.992,167.08 , 176.944, 186.808, 196.672,138.648, 146.046, 153.444, 160.842,310.236, 325.194, 340.152, 355.11 ,317.58 , 332.862, 348.144, 363.426, 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f,
148.692, 156.576, 164.46 , 172.344,332.268, 348.198, 364.128, 380.058,339.612, 355.866, 372.12 , 388.374,125.228, 130.646, 136.064, 141.482,274.792, 285.736, 296.68 , 307.624,280.552, 291.712, 302.872, 314.032, 92.684, 98.75 , 104.816, 110.882,211.432, 223.672, 235.912, 248.152,215.464, 227.92 , 240.376, 252.832, 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f,
178.824, 188.166, 197.508, 206.85 ,398.364, 417.21 , 436.056, 454.902,405.708, 424.878, 444.048, 463.218,188.868, 198.696, 208.524, 218.352,420.396, 440.214, 460.032, 479.85 ,427.74 , 447.882, 468.024, 488.166,157.196, 163.91 , 170.624, 177.338,343.912, 357.448, 370.984, 384.52 ,349.672, 363.424, 377.176, 390.928}); 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12,120.96, 122.04, 123.12, 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. , 79.56, 80.28, 81. , auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f,
154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,154.8 , 156.24, 157.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68,101.76, 102.72, 103.68, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f,
111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 ,111.24, 112.32, 113.4 , 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 73.08, 73.8 , 74.52, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f,
67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 67.68, 68.4 , 69.12, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 44.4 , 44.88, 45.36, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f,
85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 85.92, 86.88, 87.84, 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 56.32, 56.96, 57.6 , 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f,
61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 61.2 , 61.92, 62.64, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04, 40.08, 40.56, 41.04}); 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f});
// auto expGradB('c', {oC},{}); // auto expGradB('c', {oC},{});
input = 2.; input = 2.;
@ -1303,19 +1302,19 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014, 0.032, 0.05 , 0.068, 0.118, 0.181, 0.244, 0.307, 0.212, 0.257, 0.302, 0.347, 0.208, 0.298, 0.388, 0.478, 1.028, 1.262, 1.496, 1.73 , 1.036, 1.18 , 1.324, 1.468, 0.928, 1.018, 1.108, 1.198, 2.9 , 3.134, 3.368, 3.602, 2.188, 2.332, 2.476, 2.62 , auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f,
1.202, 1.274, 1.346, 1.418, 3.142, 3.313, 3.484, 3.655, 2.048, 2.147, 2.246, 2.345, 0.532, 0.676, 0.82 , 0.964, 2.324, 2.666, 3.008, 3.35 , 2.008, 2.206, 2.404, 2.602, 3.584, 3.98 , 4.376, 4.772,10.552,11.452,12.352,13.252, 7.4 , 7.904, 8.408, 8.912, 1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f,
6.752, 7.148, 7.544, 7.94 ,17.752,18.652,19.552,20.452,11.432,11.936,12.44 ,12.944, 5.932, 6.184, 6.436, 6.688,14.42 ,14.978,15.536,16.094, 8.704, 9.01 , 9.316, 9.622, 3.11 , 3.236, 3.362, 3.488, 7.39 , 7.669, 7.948, 8.227, 4.388, 4.541, 4.694, 4.847, 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f,
8.56 , 8.866, 9.172, 9.478,19.892,20.558,21.224,21.89 ,11.548,11.908,12.268,12.628,11.008,11.314,11.62 ,11.926,25.22 ,25.886,26.552,27.218,14.428,14.788,15.148,15.508, 7.322, 7.502, 7.682, 7.862,16.462,16.849,17.236,17.623, 9.248, 9.455, 9.662, 9.869, 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f,
0.158, 0.392, 0.626, 0.86 , 1.27 , 1.765, 2.26 , 2.755, 1.22 , 1.481, 1.742, 2.003, 2.224, 2.746, 3.268, 3.79 , 6.788, 7.886, 8.984,10.082, 4.78 , 5.356, 5.932, 6.508, 6.4 , 6.922, 7.444, 7.966,15.572,16.67 ,17.768,18.866, 9.388, 9.964,10.54 ,11.116, 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f,
4.802, 5.09 , 5.378, 5.666,11.206,11.809,12.412,13.015, 6.512, 6.827, 7.142, 7.457, 6.004, 6.58 , 7.156, 7.732,14.996,16.202,17.408,18.614, 9.208, 9.838,10.468,11.098,17.984,19.244,20.504,21.764,42.808,45.436,48.064,50.692,25.256,26.624,27.992,29.36 , 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f,
28.064,29.324,30.584,31.844,63.832,66.46 ,69.088,71.716,36.2 ,37.568,38.936,40.304,18.316,19. ,19.684,20.368,40.916,42.338,43.76 ,45.182,22.816,23.554,24.292,25.03 , 8.438, 8.78 , 9.122, 9.464,18.91 ,19.621,20.332,21.043,10.58 ,10.949,11.318,11.687, 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f,
20.944,21.682,22.42 ,23.158,46.388,47.918,49.448,50.978,25.66 ,26.452,27.244,28.036,26.848,27.586,28.324,29.062,58.628,60.158,61.688,63.218,31.996,32.788,33.58 ,34.372,16.106,16.502,16.898,17.294,34.894,35.713,36.532,37.351,18.896,19.319,19.742,20.165}); 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16,7.52, 7.84, 8.16}); 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f});
// auto expGradB('c', {oC},{}); // auto expGradB('c', {oC},{});
input = 2.; input = 2.;
@ -1351,23 +1350,23 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091, 4.356, 2.268, 4.53 , 9.42 , 4.896, 4.65 , 9.672, 5.028, 2.517, 5.226, 2.712, 4.932,10.242, 5.316,10.62 ,22.02 ,11.412,10.908,22.62 ,11.724, 5.868,12.15 , 6.288, 2.913, 6.03 , 3.12 , 6.234,12.888, 6.66 , 6.402,13.236, 6.84 , 3.423, 7.068, 3.648, auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f,
2.415, 5.04 , 2.628, 5.25 ,10.932, 5.688, 5.37 ,11.184, 5.82 , 2.913, 6.054, 3.144, 5.724,11.898, 6.18 ,12.348,25.62 ,13.284,12.636,26.22 ,13.596, 6.804,14.094, 7.296, 3.381, 7.002, 3.624, 7.242,14.976, 7.74 , 7.41 ,15.324, 7.92 , 3.963, 8.184, 4.224, 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f,
2.739, 5.724, 2.988, 5.97 ,12.444, 6.48 , 6.09 ,12.696, 6.612, 3.309, 6.882, 3.576, 6.516,13.554, 7.044,14.076,29.22 ,15.156,14.364,29.82 ,15.468, 7.74 ,16.038, 8.304, 3.849, 7.974, 4.128, 8.25 ,17.064, 8.82 , 8.418,17.412, 9. , 4.503, 9.3 , 4.8 , 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f,
3.063, 6.408, 3.348, 6.69 ,13.956, 7.272, 6.81 ,14.208, 7.404, 3.705, 7.71 , 4.008, 7.308,15.21 , 7.908,15.804,32.82 ,17.028,16.092,33.42 ,17.34 , 8.676,17.982, 9.312, 4.317, 8.946, 4.632, 9.258,19.152, 9.9 , 9.426,19.5 ,10.08 , 5.043,10.416, 5.376, 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f,
5.619,11.484, 5.868,11.73 ,23.964,12.24 ,12.138,24.792,12.66 , 6.333,12.93 , 6.6 ,12.42 ,25.362,12.948,25.884,52.836,26.964,26.748,54.588,27.852,13.932,28.422,14.496, 6.873,14.022, 7.152,14.298,29.16 ,14.868,14.754,30.084,15.336, 7.671,15.636, 7.968, 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f,
6.807,13.896, 7.092,14.178,28.932,14.76 ,14.586,29.76 ,15.18 , 7.593,15.486, 7.896,14.94 ,30.474,15.54 ,31.068,63.348,32.292,31.932,65.1 ,33.18 ,16.596,33.822,17.232, 8.205,16.722, 8.52 ,17.034,34.704,17.676,17.49 ,35.628,18.144, 9.075,18.48 , 9.408, 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f,
7.995,16.308, 8.316,16.626,33.9 ,17.28 ,17.034,34.728,17.7 , 8.853,18.042, 9.192,17.46 ,35.586,18.132,36.252,73.86 ,37.62 ,37.116,75.612,38.508,19.26 ,39.222,19.968, 9.537,19.422, 9.888,19.77 ,40.248,20.484,20.226,41.172,20.952,10.479,21.324,10.848, 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f,
9.183,18.72 , 9.54 ,19.074,38.868,19.8 ,19.482,39.696,20.22 ,10.113,20.598,10.488,19.98 ,40.698,20.724,41.436,84.372,42.948,42.3 ,86.124,43.836,21.924,44.622,22.704,10.869,22.122,11.256,22.506,45.792,23.292,22.962,46.716,23.76 ,11.883,24.168,12.288}); 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f,
10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4, 10.4}); 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f});
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64, 3.92, 5.2 }); auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64f, 3.92f, 5.2f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1408,10 +1407,10 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) {
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. , 2.4, 2.8, 3.2, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,
12. , 12.8, 13.6, 14.4,12. , 12.8, 13.6, 14.4, 5.2, 5.6, 6. , 6.4,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, 5.4, 6. , 6.6, 7.2, 5.6, 6.4, 7.2, 8. , 5.6, 6.4, 7.2, 8. , 2. , 2.4, 2.8, 3.2}); 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1440,8 +1439,8 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) {
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC}); auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8, auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f,
13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8,13.2, 14.4, 15.6, 16.8}); 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1698,14 +1697,14 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48., auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48., 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16., 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16., 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48.,96.,96.,96.,96.,96.,96.,48.,48.,48., 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
64.,64.,64.,64.,64.,64.,32.,32.,32.,64.,64.,64.,64.,64.,64.,32.,32.,32.,96.,96.,96.,96.,96.,96.,48.,48.,48., 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
96.,96.,96.,96.,96.,96.,48.,48.,48.,64.,64.,64.,64.,64.,64.,32.,32.,32.,32.,32.,32.,32.,32.,32.,16.,16.,16., 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
48.,48.,48.,48.,48.,48.,24.,24.,24.,48.,48.,48.,48.,48.,48.,24.,24.,24.,32.,32.,32.,32.,32.,32.,16.,16.,16.}); 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f});
input = 2.; input = 2.;
weights = 1.; weights = 1.;
@ -1730,14 +1729,14 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,534.4,540.8,547.2,534.4,540.8,547.2,248. ,251.2,254.4,686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. , 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
686.4,696. ,705.6,686.4,696. ,705.6,314.4,319.2,324. ,380.8,387.2,393.6,380.8,387.2,393.6,171.2,174.4,177.6,152. ,155.2,158.4,152. ,155.2,158.4, 66.4, 68. , 69.6, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6,170.4,175.2,180. ,170.4,175.2,180. , 70.8, 73.2, 75.6, 75.2, 78.4, 81.6, 75.2, 78.4, 81.6, 28. , 29.6, 31.2}); 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1761,10 +1760,10 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3}, {686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6,686.4,696.,705.6}); 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -1844,10 +1843,10 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49., 49.,49., 49., 49., 49.,49., 49., 50., 50.,50., 50., 50., 50.,50., 50., auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f,
51., 51.,51., 51., 51., 51.,51., 51., 49., 49.,49., 49., 49., 49.,49., 49., 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f,
50., 50.,50., 50., 50., 50.,50., 50., 51., 51.,51., 51., 51., 51.,51., 51.}); 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f});
input = 2.; input = 2.;
weights = 0.5; weights = 0.5;
@ -1873,11 +1872,11 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1,2,3}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. , auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8, 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f,
236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 236.2, 698. , 698. , 698. , 698. , 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
698. , 698. , 698. , 698. ,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8,1159.8}); 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
weights.permutei({2, 3, 4, 1, 0}); weights.permutei({2, 3, 4, 1, 0});
@ -1904,9 +1903,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW}); auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. , auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f,
1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 235.2, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f,
696. , 696. , 696. , 696. , 696. , 696. , 696. , 696. ,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8,1156.8}); 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
weights.permutei({2, 3, 4, 1, 0}); weights.permutei({2, 3, 4, 1, 0});
@ -1998,10 +1997,10 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}); auto bias = NDArrayFactory::create<TypeParam>('c', {oC});
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f,
7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f,
6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f,
5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0, 5.4, 6.2, 7.0}); 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f});
input = 2.; input = 2.;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
bias = 1.; bias = 1.;
@ -2111,21 +2110,21 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH}); auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH});
columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); columns.permutei({5, 1, 0, 2, 4, 6, 7, 3});
columns = -1.; columns = -1.;
auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
10., 11., 12., 2., 0., 4., 0., 6., 0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0.,0., 10., 0., 12., 0., 0., 0., 5., 6., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f,
9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., 0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f,
23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f,
0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f,
34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f,
0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f,
48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54., 0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., 53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f,
0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., 70., 71., 72., 0., 0., 64., 0., 66., 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f,
0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 69., 70., 71., 72., 0., 0., 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f,
0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
graph::Context context(1); graph::Context context(1);
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
@ -2146,7 +2145,7 @@ TEST_F(ConvolutionTests1, col2im_test1) {
auto columns = NDArrayFactory::create<float>('c', {bS, iC, kH, kW, oH, oW}); auto columns = NDArrayFactory::create<float>('c', {bS, iC, kH, kW, oH, oW});
columns.linspace(1); columns.linspace(1);
auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1., 7., 12., 34., 17., 39., 44., 98., 33., 71., 76., 162., 49., 103., 108., 226.}); auto imageExpected = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f});
LaunchContext ctx; LaunchContext ctx;
nd4j::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW); nd4j::ops::helpers::col2im(ctx, columns, image, sH, sW, pH, pW, iH, iW, dH, dW);
@ -2165,12 +2164,12 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
input.linspace(1); input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 4., 5., 6., auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f,
7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,10., 11., 12., 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,16., 17., 18., 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,22., 23., 24., 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f,
25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,28., 29., 30., 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,
31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,34., 35., 36.}); 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f});
nd4j::ops::upsampling2d op; nd4j::ops::upsampling2d op;
auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
@ -2193,12 +2192,12 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
input.linspace(1); input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1., 1., 1., 2., 2., 2., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 3., 3., 3., 4., 4., 4., auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f,
5., 5., 5., 6., 6., 6., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 9., 9., 9., 10., 10., 10.,11., 11., 11., 12., 12., 12.,11., 11., 11., 12., 12., 12., 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f,
13., 13., 13., 14., 14., 14.,13., 13., 13., 14., 14., 14.,15., 15., 15., 16., 16., 16.,15., 15., 15., 16., 16., 16.,17., 17., 17., 18., 18., 18.,17., 17., 17., 18., 18., 18.,19., 19., 19., 20., 20., 20.,19., 19., 19., 20., 20., 20., 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f,
21., 21., 21., 22., 22., 22.,21., 21., 21., 22., 22., 22.,23., 23., 23., 24., 24., 24.,23., 23., 23., 24., 24., 24.,25., 25., 25., 26., 26., 26.,25., 25., 25., 26., 26., 26.,27., 27., 27., 28., 28., 28.,27., 27., 27., 28., 28., 28., 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f,
29., 29., 29., 30., 30., 30.,29., 29., 29., 30., 30., 30.,31., 31., 31., 32., 32., 32.,31., 31., 31., 32., 32., 32., 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f,
33., 33., 33., 34., 34., 34.,33., 33., 33., 34., 34., 34.,35., 35., 35., 36., 36., 36.,35., 35., 35., 36., 36., 36.}); 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f});
nd4j::ops::upsampling2d op; nd4j::ops::upsampling2d op;
auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW}); auto results = op.execute({&input}, {}, {factorH, factorW, isNCHW});
@ -2222,21 +2221,21 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
input.linspace(1); input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 1., 2., 3., 1., 2., 3., 4., 5., 6., 4., 5., 6., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12., 7., 8., 9., 7., 8., 9.,10., 11., 12.,10., 11., 12.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18., 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18., 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
13., 14., 15.,13., 14., 15.,16., 17., 18.,16., 17., 18.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,19., 20., 21.,19., 20., 21.,22., 23., 24.,22., 23., 24.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30., 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,
25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36., 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,
25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,25., 26., 27.,25., 26., 27.,28., 29., 30.,28., 29., 30.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36., 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,
31., 32., 33.,31., 32., 33.,34., 35., 36.,34., 35., 36.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48., 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f,
43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42.,37., 38., 39.,37., 38., 39.,40., 41., 42.,40., 41., 42., 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f,
43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,43., 44., 45.,43., 44., 45.,46., 47., 48.,46., 47., 48.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54., 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,
49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54., 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,
49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,49., 50., 51.,49., 50., 51.,52., 53., 54.,52., 53., 54.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60.,55., 56., 57.,55., 56., 57.,58., 59., 60.,58., 59., 60., 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f,
61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72., 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,61., 62., 63.,61., 62., 63.,64., 65., 66.,64., 65., 66.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72., 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.,67., 68., 69.,67., 68., 69.,70., 71., 72.,70., 71., 72.}); 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f});
nd4j::ops::upsampling3d op; nd4j::ops::upsampling3d op;
auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
@ -2259,18 +2258,18 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
input.linspace(1); input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 3., 3., 4., 4., 3., 3., 4., 4., 3., 3., 4., 4., 5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8., auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f,
5., 5., 6., 6., 5., 5., 6., 6., 5., 5., 6., 6., 7., 7., 8., 8., 7., 7., 8., 8., 7., 7., 8., 8., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., 9., 9., 10., 10., 9., 9., 10., 10., 9., 9., 10., 10.,11., 11., 12., 12.,11., 11., 12., 12.,11., 11., 12., 12., 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f,
13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,13., 13., 14., 14.,13., 13., 14., 14.,13., 13., 14., 14.,15., 15., 16., 16.,15., 15., 16., 16.,15., 15., 16., 16.,17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20., 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f,
17., 17., 18., 18.,17., 17., 18., 18.,17., 17., 18., 18.,19., 19., 20., 20.,19., 19., 20., 20.,19., 19., 20., 20.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24.,21., 21., 22., 22.,21., 21., 22., 22.,21., 21., 22., 22.,23., 23., 24., 24.,23., 23., 24., 24.,23., 23., 24., 24., 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f,
25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,25., 25., 26., 26.,25., 25., 26., 26.,25., 25., 26., 26.,27., 27., 28., 28.,27., 27., 28., 28.,27., 27., 28., 28.,29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32., 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f,
29., 29., 30., 30.,29., 29., 30., 30.,29., 29., 30., 30.,31., 31., 32., 32.,31., 31., 32., 32.,31., 31., 32., 32.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36.,33., 33., 34., 34.,33., 33., 34., 34.,33., 33., 34., 34.,35., 35., 36., 36.,35., 35., 36., 36.,35., 35., 36., 36., 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f,
37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,37., 37., 38., 38.,37., 37., 38., 38.,37., 37., 38., 38.,39., 39., 40., 40.,39., 39., 40., 40.,39., 39., 40., 40.,41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44., 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f,
41., 41., 42., 42.,41., 41., 42., 42.,41., 41., 42., 42.,43., 43., 44., 44.,43., 43., 44., 44.,43., 43., 44., 44.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48.,45., 45., 46., 46.,45., 45., 46., 46.,45., 45., 46., 46.,47., 47., 48., 48.,47., 47., 48., 48.,47., 47., 48., 48., 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f,
49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,49., 49., 50., 50.,49., 49., 50., 50.,49., 49., 50., 50.,51., 51., 52., 52.,51., 51., 52., 52.,51., 51., 52., 52.,53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56., 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f,
53., 53., 54., 54.,53., 53., 54., 54.,53., 53., 54., 54.,55., 55., 56., 56.,55., 55., 56., 56.,55., 55., 56., 56.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60.,57., 57., 58., 58.,57., 57., 58., 58.,57., 57., 58., 58.,59., 59., 60., 60.,59., 59., 60., 60.,59., 59., 60., 60., 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f,
61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,61., 61., 62., 62.,61., 61., 62., 62.,61., 61., 62., 62.,63., 63., 64., 64.,63., 63., 64., 64.,63., 63., 64., 64.,65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68., 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f,
65., 65., 66., 66.,65., 65., 66., 66.,65., 65., 66., 66.,67., 67., 68., 68.,67., 67., 68., 68.,67., 67., 68., 68.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.,69., 69., 70., 70.,69., 69., 70., 70.,69., 69., 70., 70.,71., 71., 72., 72.,71., 71., 72., 72.,71., 71., 72., 72.}); 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f});
nd4j::ops::upsampling3d op; nd4j::ops::upsampling3d op;
auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW}); auto results = op.execute({&input}, {}, {factorD, factorH, factorW, isNCDHW});
@ -2413,14 +2412,14 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC}); auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f,
2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75}); 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f});
input = 0.5; input = 0.5;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -2446,14 +2445,14 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}); auto input = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, oC}); auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, oC});
auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. }); 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f });
input = 0.5; input = 0.5;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -2480,10 +2479,10 @@ TEST_F(ConvolutionTests1, deconv2d_test3) {
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC}); auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
auto bias = NDArrayFactory::create<float>('c', {oC}); auto bias = NDArrayFactory::create<float>('c', {oC});
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, -16.1, auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, -16.1f,
-20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f,
-32.8, -36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, -32.8f, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f,
-7.4, -8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2}); -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f});
input.linspace(-10, 0.5); input.linspace(-10, 0.5);
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -2568,17 +2567,17 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) {
int dataFormat = 0; // 1-NHWC, 0-NCHW int dataFormat = 0; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, oC, iC}, {1., 76., 151., 26., 101., 176., 51., 126., 201., 2., 77., 152., 27., 102., 177., 52., 127., 202., 3., 78., 153., 28., 103., 178., 53., 128., 203., auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, oC, iC}, {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f,
4., 79., 154., 29., 104., 179., 54., 129., 204., 5., 80., 155., 30., 105., 180., 55., 130., 205., 6., 81., 156., 31., 106., 181., 56., 131., 206., 4.f, 79.f, 154.f, 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, 181.f, 56.f, 131.f, 206.f,
7., 82., 157., 32., 107., 182., 57., 132., 207., 8., 83., 158., 33., 108., 183., 58., 133., 208., 9., 84., 159., 34., 109., 184., 59., 134., 209., 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, 209.f,
10., 85., 160., 35., 110., 185., 60., 135., 210., 11., 86., 161., 36., 111., 186., 61., 136., 211., 12., 87., 162., 37., 112., 187., 62., 137., 212., 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f,
13., 88., 163., 38., 113., 188., 63., 138., 213., 14., 89., 164., 39., 114., 189., 64., 139., 214., 15., 90., 165., 40., 115., 190., 65., 140., 215., 13.f, 88.f, 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, 115.f, 190.f, 65.f, 140.f, 215.f,
16., 91., 166., 41., 116., 191., 66., 141., 216., 17., 92., 167., 42., 117., 192., 67., 142., 217., 18., 93., 168., 43., 118., 193., 68., 143., 218., 16.f, 91.f, 166.f, 41.f, 116.f, 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, 143.f, 218.f,
19., 94., 169., 44., 119., 194., 69., 144., 219., 20., 95., 170., 45., 120., 195., 70., 145., 220., 21., 96., 171., 46., 121., 196., 71., 146., 221., 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f,
22., 97., 172., 47., 122., 197., 72., 147., 222., 23., 98., 173., 48., 123., 198., 73., 148., 223., 24., 99., 174., 49., 124., 199., 74., 149., 224., 22.f, 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, 49.f, 124.f, 199.f, 74.f, 149.f, 224.f,
25., 100., 175.,50., 125., 200.,75., 150., 225.}); 25.f, 100.f, 175.f,50.f, 125.f, 200.f,75.f, 150.f, 225.f});
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {6276.0, 12831.0, 19668.0, 26790.0, 27012.0, 20703.0, 14100.0, 7200.0, 13719.0, 28023.0, 42918.0, 58410.0, 58902.0, 45105.0, 30693.0, 15660.0, 22389.0, 45696.0, 69930.0, 95100.0, 95910.0, 73386.0, 49899.0, 25440.0, 32346.0, 65970.0, 100884.0, 137100.0, 138276.0, 105726.0, 71838.0, 36600.0, 33726.0, 68790.0, 105204.0, 142980.0, 144156.0, 110226.0, 74898.0, 38160.0, 27555.0, 56154.0, 85806.0, 116520.0, 117474.0, 89748.0, 60933.0, 31020.0, 19917.0, 40557.0, 61926.0, 84030.0, 84714.0, 64671.0, 43875.0, 22320.0, 10752.0, 21879.0, 33384.0, 45270.0, 45636.0, 34815.0, 23604.0, 12000.0, 7551.0, 15456.0, 23718.0, 32340.0, 32562.0, 24978.0, 17025.0, 8700.0, 16569.0, 33873.0, 51918.0, 70710.0, 71202.0, 54555.0, 37143.0, 18960.0, 27114.0, 55371.0, 84780.0, 115350.0, 116160.0, 88911.0, 60474.0, 30840.0, 39246.0, 80070.0, 122484.0, 166500.0, 167676.0, 128226.0, 87138.0, 44400.0, 40626.0, 82890.0, 126804.0, 172380.0, 173556.0, 132726.0, 90198.0, 45960.0, 33180.0, 67629.0, 103356.0, 140370.0, 141324.0, 107973.0, 73308.0, 37320.0, 23967.0, 48807.0, 74526.0, 101130.0, 101814.0, 77721.0, 52725.0, 26820.0, 12927.0, 26304.0, 40134.0, 54420.0, 54786.0, 41790.0, 28329.0, 14400.0, 8826.0, 18081.0, 27768.0, 37890.0, 38112.0, 29253.0, 19950.0, 10200.0, 19419.0, 39723.0, 60918.0, 83010.0, 83502.0, 64005.0, 43593.0, 22260.0, 31839.0, 65046.0, 99630.0, 135600.0, 136410.0, 104436.0, 71049.0, 36240.0, 46146.0, 94170.0, 144084.0, 195900.0, 197076.0, 150726.0, 102438.0, 52200.0, 47526.0, 96990.0, 148404.0, 201780.0, 202956.0, 155226.0, 105498.0, 53760.0, 38805.0, 79104.0, 120906.0, 164220.0, 165174.0, 126198.0, 85683.0, 43620.0, 28017.0, 57057.0, 87126.0, 118230.0, 118914.0, 90771.0, 61575.0, 31320.0, 15102.0, 30729.0, 46884.0, 63570.0, 63936.0, 48765.0, 33054.0, 16800.0, 17220.0, 34863.0, 52932.0, 71430.0, 72228.0, 54831.0, 36996.0, 18720.0, 36327.0, 73527.0, 111606.0, 150570.0, 152214.0, 115521.0, 77925.0, 39420.0, 57381.0, 116112.0, 176202.0, 237660.0, 240198.0, 182250.0, 122907.0, 62160.0, 80442.0, 162738.0, 246900.0, 332940.0, 336420.0, 255198.0, 172062.0, 87000.0, 84702.0, 171318.0, 259860.0, 350340.0, 353820.0, 268338.0, 180882.0, 91440.0, 66867.0, 135210.0, 205038.0, 276360.0, 279042.0, 211572.0, 142581.0, 72060.0, 46845.0, 94701.0, 143574.0, 193470.0, 195306.0, 148047.0, 99747.0, 50400.0, 24576.0, 49671.0, 75288.0, 101430.0, 102372.0, 77583.0, 52260.0, 26400.0, 22095.0, 44688.0, 67782.0, 91380.0, 92178.0, 69906.0, 47121.0, 23820.0, 46377.0, 93777.0, 142206.0, 191670.0, 193314.0, 146571.0, 98775.0, 49920.0, 72906.0, 147387.0, 223452.0, 301110.0, 303648.0, 230175.0, 155082.0, 78360.0, 101742.0, 205638.0, 311700.0, 419940.0, 423420.0, 320898.0, 216162.0, 109200.0, 106002.0, 214218.0, 324660.0, 437340.0, 440820.0, 334038.0, 224982.0, 113640.0, 83292.0, 168285.0, 254988.0, 343410.0, 346092.0, 262197.0, 176556.0, 89160.0, 58095.0, 117351.0, 177774.0, 239370.0, 241206.0, 182697.0, 122997.0, 62100.0, 30351.0, 61296.0, 92838.0, 124980.0, 125922.0, 95358.0, 64185.0, 32400.0, 26970.0, 54513.0, 82632.0, 111330.0, 112128.0, 84981.0, 57246.0, 28920.0, 56427.0, 114027.0, 172806.0, 232770.0, 234414.0, 177621.0, 119625.0, 60420.0, 88431.0, 178662.0, 270702.0, 364560.0, 367098.0, 278100.0, 187257.0, 94560.0, 123042.0, 248538.0, 376500.0, 506940.0, 510420.0, 386598.0, 260262.0, 131400.0, 127302.0, 257118.0, 389460.0, 524340.0, 527820.0, 399738.0, 269082.0, 135840.0, 99717.0, 201360.0, 304938.0, 410460.0, 413142.0, 312822.0, 210531.0, 106260.0, 69345.0, 140001.0, 211974.0, 285270.0, 287106.0, 217347.0, 146247.0, 73800.0, 36126.0, 72921.0, 110388.0, 148530.0, 149472.0, 113133.0, 76110.0, 38400.0}); auto exp = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW}, {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, 71838.0f, 36600.0f, 33726.0f, 68790.0f, 105204.0f, 142980.0f, 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, 353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f});
input.linspace(1); input.linspace(1);
@ -2674,14 +2673,14 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)}); auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, { 2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75, 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f,
2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 27.75, 32.75, 37.75, 42.75, 47.75, 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,105.5 , 115.5 , 125.5 , 135.5 , 145.5 , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75, 57.75, 62.75, 67.75, 72.75,130.5 , 140.5 , 150.5 , 160.5 , 170.5 ,130.5 , 140.5 , 150.5 , 160.5 , 170.5 , 77.75, 82.75, 87.75, 92.75, 97.75}); 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f});
input = 0.5; input = 0.5;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);

View File

@ -110,14 +110,14 @@ TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC}); auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)}); auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
2.75, 7.75, 12.75, 17.75, 22.75, 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 30.5 , 40.5 , 50.5 , 60.5 , 70.5 , 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. , 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5 , 65.5 , 75.5 , 85.5 , 95.5 ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. ,161. , 181. , 201. , 221. , 241. }); 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f});
input = 0.5; input = 0.5;
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -150,7 +150,7 @@ TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) {
auto input0 = NDArrayFactory::create<TypeParam>('c', {4}, {3, 8, 8, 16}); auto input0 = NDArrayFactory::create<TypeParam>('c', {4}, {3.f, 8.f, 8.f, 16.f});
auto input1 = NDArrayFactory::create<TypeParam>('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, auto input1 = NDArrayFactory::create<TypeParam>('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f,
-1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, -1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f,
@ -569,7 +569,6 @@ TEST_F(ConvolutionTests2, deconv3d_test4) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests2, deconv3d_test5) { TEST_F(ConvolutionTests2, deconv3d_test5) {
int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2;
int iD=3,iH=3,iW=3; int iD=3,iH=3,iW=3;
int paddingMode = 0; // 1-SAME, 0-VALID; int paddingMode = 0; // 1-SAME, 0-VALID;
@ -579,22 +578,22 @@ TEST_F(ConvolutionTests2, deconv3d_test5) {
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, oC, iC}); auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, oC, iC});
auto bias = NDArrayFactory::create<float>('c', {oC}); auto bias = NDArrayFactory::create<float>('c', {oC});
auto exp = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC}, {-2.9, -6.8, -10.7, -2.6, -6.1, -9.6, -16.9, -23.9, -30.9, -13.1, -16.6, -20.1, -11.6, -14.7, -17.8, -2.0, -4.7, -7.4, -1.7, -4.0, -6.3, -11.5, auto exp = NDArrayFactory::create<float>('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f,
-16.1, -20.7, -8.6, -10.9, -13.2, -7.1, -9.0, -10.9, -27.4, -32.8, -38.2, -24.4, -29.0, -33.6, -65.0, -74.2, -83.4, -38.2, -42.8, -47.4, -32.8, -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f,
-36.6, -40.4, -18.2, -20.9, -23.6, -15.5, -17.8, -20.1, -39.1, -43.7, -48.3, -22.4, -24.7, -27.0, -18.5, -20.4, -22.3, -10.1, -11.6, -13.1, -7.4, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f,
-8.5, -9.6, -19.3, -21.5, -23.7, -10.7, -11.8, -12.9, -6.8, -7.5, -8.2, -0.2, -0.5, -0.8, 0.1, 0.2, 0.3, -0.7, -0.5, -0.3, 0.4, 0.5, 0.6, 1.9, 2.4, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f,
2.9, 0.7, 1.6, 2.5, 1.0, 2.3, 3.6, 4.7, 7.3, 9.9, 4.9, 6.2, 7.5, 6.4, 8.1, 9.8, -0.4, 1.4, 3.2, 2.6, 5.2, 7.8, 10.6, 15.8, 21.0, 10.4, 13.0, 15.6, 2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f,
15.8, 19.2, 22.6, 6.1, 7.0, 7.9, 8.8, 10.1, 11.4, 20.3, 22.9, 25.5, 12.7, 14.0, 15.3, 16.6, 18.3, 20.0, 14.2, 16.3, 18.4, 16.9, 19.4, 21.9, 40.1, 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f,
45.1, 50.1, 24.4, 26.9, 29.4, 28.3, 31.2, 34.1, -47.2, -47.8, -48.4, -41.8, -41.6, -41.4, -85.4, -85., -84.6, -41.2, -41.0, -40.8, -33.4, -32.4, -31.4, 45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f,
-31., -29.2, -27.4, -25.6, -23.0, -20.4, -45.8, -40.6, -35.4, -17.8, -15.2, -12.6, -10.0, -6.6, -3.2, -65.6, -62.0, -58.4, -50.0, -44.8, -39.6, -89.2, -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f,
-78.8, -68.4, -34.4, -29.2, -24., -14.0, -7.2, -0.4, -20.2, -18.4, -16.6, -10., -7.4, -4.8, -14.6, -9.4, -4.2, -2.2, 0.4, 3.0, 10.4, 13.8, 17.2, 10.4, -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f,
14.6, 18.8, 20.6, 25.6, 30.6, 53.8, 63.8, 73.8, 35.6, 40.6, 45.6, 48.2, 54.0, 59.8, -3.8, -4.1, -4.4, 1.3, 1.4, 1.5, 1.7, 1.9, 2.1, 1.6, 1.7, 1.8, 7.9, 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f,
8.4, 8.9, 11.5, 12.4, 13.3, 16.6, 17.9, 19.2, 35.9, 38.5, 41.1, 20.5, 21.8, 23.1, 26.8, 28.5, 30.2, 21.2, 23.0, 24.8, 33.8, 36.4, 39.0, 73.0, 78.2, 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f,
83.4, 41.6, 44.2, 46.8, 56.6, 60.0, 63.4, 16.9, 17.8, 18.7, 24.4, 25.7, 27., 51.5, 54.1, 56.7, 28.3, 29.6, 30.9, 37.0, 38.7, 40.4, 39.4, 41.5, 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f,
43.6, 46.9, 49.4, 51.9, 100.1, 105.1, 110.1, 54.4, 56.9, 59.4, 63.1, 66.0, 68.9, 42.1, 45.4, 48.7, 47.2, 50.9, 54.6, 104.3, 111.7, 43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f,
119.1, 58.3, 62.0, 65.7, 64.6, 68.7, 72.8, 57.4, 61.9, 66.4, 62.5, 67.4, 72.3, 138.5, 148.3, 158.1, 77.2, 82.1, 87.0, 83.5, 88.8, 94.1, 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f,
134.6, 143.6, 152.6, 147.2, 157.0, 166.8, 321.4, 341.0, 360.6, 176.6, 186.4, 196.2, 191.6, 202.2, 212.8, 84.4, 88.9, 134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f,
93.4, 91.9, 96.8, 101.7, 197.3, 207.1, 216.9, 106.6, 111.5, 116.4, 115.3, 120.6, 125.9, 106.9, 112.6, 118.3, 114.4, 120.5, 126.6, 245.9, 258.1, 270.3, 132.7, 138.8, 144.9, 141.4, 147.9, 154.4}); 93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f});
input.linspace(-10, 0.5); input.linspace(-10, 0.5);
weights.linspace(0.1, 0.1); weights.linspace(0.1, 0.1);
@ -699,7 +698,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test3) {
int dataFormat = 0; // 1-NDHWC, 0-NCDHW int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW}); auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6}); auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, nd4j::DataType::FLOAT32); NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, nd4j::DataType::FLOAT32);
@ -734,7 +733,7 @@ TEST_F(ConvolutionTests2, deconv3d_bp_test4) {
int dataFormat = 0; // 1-NDHWC, 0-NCDHW int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW}); auto input = NDArrayFactory::create<float>('c', {bS, oC, oD, oH, oW});
auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1,0.9,0.2,0.1,0.3,1.1,0.4,1.2,0.5,1.3,0.6,1.4,0.7,1.5,0.8,1.6}); auto weights = NDArrayFactory::create<float>('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f});
auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW}); auto gradO = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, nd4j::DataType::FLOAT32); NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, nd4j::DataType::FLOAT32);
@ -1062,14 +1061,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
int paddingMode = 0; // 1-SAME, 0-VALID; int paddingMode = 0; // 1-SAME, 0-VALID;
int dataFormat = 0; // 1-NHWC, 0-NCHW int dataFormat = 0; // 1-NHWC, 0-NCHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.27620894, 0.21801452, 0.062078513, 7.348895E-4, 0.24149609, 0.4948205, 0.93483436, 0.52035654, 0.30292067, 0.3289706, 0.7977864, auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f,
0.03180518, 0.1455722, 0.90352905, 0.9405744, 0.0048329555, 0.44062102, 0.111197524, 0.31742015, 0.1933705, 0.23825112, 0.35076278, 0.7135856, 0.28229436, 0.18310733, 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f,
0.9613717, 0.56823575, 0.78289545, 0.62195826, 0.5244586, 0.5040889, 0.025349546, 0.41400263, 0.28420195, 0.8536445, 0.3044107, 0.7997134, 0.45762005, 0.7653578, 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f,
0.07198584, 0.5304998, 0.7334402, 0.85019743, 0.031957153, 0.37088063, 0.85722464, 0.06376881, 0.39791203}); 0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f});
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}, {0.4948205, 0.93483436, 0.93483436, 0.4948205, 0.93483436, 0.93483436, 0.90352905, 0.9405744, 0.9405744, 0.44062102, 0.7135856, auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f,
0.7135856, 0.9613717, 0.9613717, 0.78289545, 0.9613717, 0.9613717, 0.78289545, 0.7997134, 0.8536445, 0.8536445, 0.7997134, 0.85019743, 0.85019743, 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f,
0.85722464, 0.85722464, 0.85019743}); 0.85722464f, 0.85722464f, 0.85019743f});
nd4j::ops::maxpool2d op; nd4j::ops::maxpool2d op;
auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); auto results = op.execute({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode});
@ -1108,9 +1107,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) {
int dataFormat = 0; // 1-NDHWC, 0-NCDHW int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {10.5, 11.5, 13.5, 14.5, 22.5, 23.5, 25.5, 26.5, 46.5, 47.5, 49.5, 50.5, 58.5, 59.5, 61.5, 62.5, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f,
82.5, 83.5, 85.5, 86.5, 94.5, 95.5, 97.5, 98.5,118.5,119.5,121.5,122.5,130.5,131.5,133.5,134.5, 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f,
154.5,155.5,157.5,158.5,166.5,167.5,169.5,170.5,190.5,191.5,193.5,194.5,202.5,203.5,205.5,206.5}); 154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::avgpool3dnew op; nd4j::ops::avgpool3dnew op;
@ -1133,12 +1132,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) {
int dataFormat = 1; // 1-NDHWC, 0-NCDHW int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 25. , 26. , 27. , 28. , 29. , 30. , 29.5, 30.5, 31.5, 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 34. , 35. , 36. , 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 43. , 44. , 45. , 43. , 44. , 45. , 46. , 47. , 48. , 47.5, 48.5, 49.5, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f,
61. , 62. , 63. , 64. , 65. , 66. , 65.5, 66.5, 67.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5, 70. , 71. , 72. , 74.5, 75.5, 76.5, 77.5, 78.5, 79.5, 79. , 80. , 81. , 79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f,
79. , 80. , 81. , 82. , 83. , 84. , 83.5, 84.5, 85.5, 83.5, 84.5, 85.5, 86.5, 87.5, 88.5, 88. , 89. , 90. , 92.5, 93.5, 94.5, 95.5, 96.5, 97.5, 97. , 98. , 99. , 97. , 98. , 99. ,100. ,101. ,102. ,101.5,102.5,103.5, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f,
133. ,134. ,135. ,136. ,137. ,138. ,137.5,138.5,139.5,137.5,138.5,139.5,140.5,141.5,142.5,142. ,143. ,144. ,146.5,147.5,148.5,149.5,150.5,151.5,151. ,152. ,153. ,151. ,152. ,153. ,154. ,155. ,156. ,155.5,156.5,157.5, 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f,
169. ,170. ,171. ,172. ,173. ,174. ,173.5,174.5,175.5,173.5,174.5,175.5,176.5,177.5,178.5,178. ,179. ,180. ,182.5,183.5,184.5,185.5,186.5,187.5,187. ,188. ,189. ,187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5, 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f,
187. ,188. ,189. ,190. ,191. ,192. ,191.5,192.5,193.5,191.5,192.5,193.5,194.5,195.5,196.5,196. ,197. ,198. ,200.5,201.5,202.5,203.5,204.5,205.5,205. ,206. ,207. ,205. ,206. ,207. ,208. ,209. ,210. ,209.5,210.5,211.5}); 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::avgpool3dnew op; nd4j::ops::avgpool3dnew op;
@ -1161,9 +1160,9 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) {
int dataFormat = 1; // 1-NDHWC, 0-NCDHW int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 29.5, 30.5, 31.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 41.5, 42.5, 43.5, 65.5, 66.5, 67.5, 68.5, 69.5, 70.5, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f,
74.5, 75.5, 76.5, 77.5, 78.5, 79.5,137.5,138.5,139.5,140.5,141.5,142.5,146.5,147.5,148.5,149.5,150.5,151.5, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f,
173.5,174.5,175.5,176.5,177.5,178.5,182.5,183.5,184.5,185.5,186.5,187.5}); 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::avgpool3dnew op; nd4j::ops::avgpool3dnew op;
@ -1186,24 +1185,24 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) {
int dataFormat = 0; // 1-NDHWC, 0-NCDHW int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{0.416667, 1.00, 1.333333, 0.75, 1.00, 2.25, 2.75, 1.50, 1.75, 3.75, 4.25, 2.25, 1.416667, 3.00, 3.333333, 1.75, 2.833333, 6.00, 6.666667, 3.50, 5.00, 10.50, 11.50, 6.00, 6.50, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f,
13.50, 14.50, 7.50, 4.833333, 10.00, 10.666667, 5.50, 6.833333, 14.00, 14.666667, 7.50, 11.00, 22.50, 23.50, 12.00, 12.50, 25.50, 26.50, 13.50, 8.833333, 18.00, 18.666666, 9.50, 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f,
4.416667, 9.00, 9.333333, 4.75, 7.00, 14.25, 14.75, 7.50, 7.75, 15.75, 16.25, 8.25, 5.416667, 11.00, 11.333333, 5.75, 6.416667, 13.00, 13.333333, 6.75, 10.00, 20.25, 20.75, 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f,
10.50, 10.75, 21.75, 22.25, 11.25, 7.416667, 15.00, 15.333333, 7.75, 14.833333, 30.00, 30.666666, 15.50, 23.00, 46.50, 47.50, 24.00, 24.50, 49.50, 50.50, 25.50, 16.833334, 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f,
34.00, 34.666668, 17.50, 18.833334, 38.00, 38.666668, 19.50, 29.00, 58.50, 59.50, 30.00, 30.50, 61.50, 62.50, 31.50, 20.833334, 42.00, 42.666668, 21.50, 10.416667, 21.00, 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 21.00f,
21.333334, 10.75, 16.00, 32.25, 32.75, 16.50, 16.75, 33.75, 34.25, 17.25, 11.416667, 23.00, 23.333334, 11.75, 12.416667, 25.00, 25.333334, 12.75, 19.00, 38.25, 38.75, 19.50, 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f,
19.75, 39.75, 40.25, 20.25, 13.416667, 27.00, 27.333334, 13.75, 26.833334, 54.00, 54.666668, 27.50, 41.00, 82.50, 83.50, 42.00, 42.50, 85.50, 86.50, 43.50, 28.833334, 58.00, 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f,
58.666668, 29.50, 30.833334, 62.00, 62.666668, 31.50, 47.00, 94.50, 95.50, 48.00, 48.50, 97.50, 98.50, 49.50, 32.833332, 66.00, 66.666664, 33.50, 16.416666, 33.00, 33.333332, 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f,
16.75, 25.00, 50.25, 50.75, 25.50, 25.75, 51.75, 52.25, 26.25, 17.416666, 35.00, 35.333332, 17.75, 18.416666, 37.00, 37.333332, 18.75, 28.00, 56.25, 56.75, 28.50, 28.75, 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f,
57.75, 58.25, 29.25, 19.416666, 39.00, 39.333332, 19.75, 38.833332, 78.00, 78.666664, 39.50, 59.00, 118.50, 119.50, 60.00, 60.50, 121.50, 122.50, 61.50, 40.833332, 82.00, 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f,
82.666664, 41.50, 42.833332, 86.00, 86.666664, 43.50, 65.00, 130.50, 131.50, 66.00, 66.50, 133.50, 134.50, 67.50, 44.833332, 90.00, 90.666664, 45.50, 22.416666, 45.00, 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f,
45.333332, 22.75, 34.00, 68.25, 68.75, 34.50, 34.75, 69.75, 70.25, 35.25, 23.416666, 47.00, 47.333332, 23.75, 24.416666, 49.00, 49.333332, 24.75, 37.00, 74.25, 74.75, 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f,
37.50, 37.75, 75.75, 76.25, 38.25, 25.416666, 51.00, 51.333332, 25.75, 50.833332, 102.00, 102.666664, 51.50, 77.00, 154.50, 155.50, 78.00, 78.50, 157.50, 158.50, 79.50, 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f,
52.833332, 106.00, 106.666664, 53.50, 54.833332, 110.00, 110.666664, 55.50, 83.00, 166.50, 167.50, 84.00, 84.50, 169.50, 170.50, 85.50, 56.833332, 114.00, 114.666664, 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f,
57.50, 28.416666, 57.00, 57.333332, 28.75, 43.00, 86.25, 86.75, 43.50, 43.75, 87.75, 88.25, 44.25, 29.416666, 59.00, 59.333332, 29.75, 30.416666, 61.00, 61.333332, 30.75, 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f,
46.00, 92.25, 92.75, 46.50, 46.75, 93.75, 94.25, 47.25, 31.416666, 63.00, 63.333332, 31.75, 62.833332, 126.00, 126.666664, 63.50, 95.00, 190.50, 191.50, 96.00, 96.50, 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f,
193.50, 194.50, 97.50, 64.833336, 130.00, 130.666672, 65.50, 66.833336, 134.00, 134.666672, 67.50, 101.00, 202.50, 203.50, 102.00, 102.50, 205.50, 206.50, 103.50, 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f,
68.833336, 138.00, 138.666672, 69.50, 34.416668, 69.00, 69.333336, 34.75, 52.00, 104.25, 104.75, 52.50, 52.75, 105.75, 106.25, 53.25, 35.416668, 71.00, 71.333336, 35.75}); 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::avgpool3dnew op; nd4j::ops::avgpool3dnew op;
@ -1226,8 +1225,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) {
int dataFormat = 0; // 1-NDHWC, 0-NCDHW int dataFormat = 0; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {20., 21., 23., 24., 32., 33., 35., 36., 56., 57., 59., 60., 68., 69., 71., 72., 92., 93., 95., 96.,104.,105.,107.,108., auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f,
128.,129.,131.,132.,140.,141.,143.,144.,164.,165.,167.,168.,176.,177.,179.,180.,200.,201.,203.,204.,212.,213.,215.,216.}); 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::maxpool3dnew op; nd4j::ops::maxpool3dnew op;
@ -1250,12 +1249,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) {
int dataFormat = 1; // 1-NDHWC, 0-NCDHW int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 49., 50., 51., 52., 53., 54., 52., 53., 54., 58., 59., 60., 61., 62., 63., 61., 62., 63., 67., 68., 69., 70., 71., 72., 70., 71., 72., 67., 68., 69., 70., 71., 72., 70., 71., 72., auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108., 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f,
85., 86., 87., 88., 89., 90., 88., 89., 90., 94., 95., 96., 97., 98., 99., 97., 98., 99.,103., 104., 105.,106., 107., 108.,106., 107., 108.,103., 104., 105.,106., 107., 108.,106., 107., 108., 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f,
157., 158., 159.,160., 161., 162.,160., 161., 162.,166., 167., 168.,169., 170., 171.,169., 170., 171.,175., 176., 177.,178., 179., 180.,178., 179., 180.,175., 176., 177.,178., 179., 180.,178., 179., 180., 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f,
193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216., 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f,
193., 194., 195.,196., 197., 198.,196., 197., 198.,202., 203., 204.,205., 206., 207.,205., 206., 207.,211., 212., 213.,214., 215., 216.,214., 215., 216.,211., 212., 213.,214., 215., 216.,214., 215., 216.}); 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::maxpool3dnew op; nd4j::ops::maxpool3dnew op;
@ -1278,8 +1277,8 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) {
int dataFormat = 1; // 1-NDHWC, 0-NCDHW int dataFormat = 1; // 1-NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, {58., 59., 60., 61., 62., 63., 67., 68., 69., 70., 71., 72., 94., 95., 96., 97., 98., 99.,103., 104., 105.,106., 107., 108., auto expected = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f,
166., 167., 168.,169., 170., 171.,175., 176., 177.,178., 179., 180.,202., 203., 204.,205., 206., 207.,211., 212., 213.,214., 215., 216.}); 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::maxpool3dnew op; nd4j::ops::maxpool3dnew op;
@ -1302,14 +1301,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) {
int dataFormat = 0; // -NDHWC, 0-NCDHW int dataFormat = 0; // -NDHWC, 0-NCDHW
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{ 4., 5., 6., 6., 7., 8., 9., 9., 10., 11., 12., 12., 10., 11., 12., 12., 16., 17., 18., 18., 19., 20., 21., 21., 22., 23., 24., 24., 22., 23., 24., 24., 28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36., auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f,
28., 29., 30., 30., 31., 32., 33., 33., 34., 35., 36., 36., 34., 35., 36., 36., 40., 41., 42., 42., 43., 44., 45., 45., 46., 47., 48., 48., 46., 47., 48., 48., 52., 53., 54., 54., 55., 56., 57., 57., 58., 59., 60., 60., 58., 59., 60., 60., 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f,
64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 64., 65., 66., 66., 67., 68., 69., 69., 70., 71., 72., 72., 70., 71., 72., 72., 76., 77., 78., 78., 79., 80., 81., 81., 82., 83., 84., 84., 82., 83., 84., 84., 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f,
88., 89., 90., 90., 91., 92., 93., 93., 94., 95., 96., 96., 94., 95., 96., 96.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108.,100., 101., 102., 102.,103., 104., 105., 105.,106., 107., 108., 108.,106., 107., 108., 108., 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f,
112., 113., 114., 114.,115., 116., 117., 117.,118., 119., 120., 120.,118., 119., 120., 120.,124., 125., 126., 126.,127., 128., 129., 129.,130., 131., 132., 132.,130., 131., 132., 132.,136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144., 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f,
136., 137., 138., 138.,139., 140., 141., 141.,142., 143., 144., 144.,142., 143., 144., 144.,148., 149., 150., 150.,151., 152., 153., 153.,154., 155., 156., 156.,154., 155., 156., 156.,160., 161., 162., 162.,163., 164., 165., 165.,166., 167., 168., 168.,166., 167., 168., 168., 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f,
172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,172., 173., 174., 174.,175., 176., 177., 177.,178., 179., 180., 180.,178., 179., 180., 180.,184., 185., 186., 186.,187., 188., 189., 189.,190., 191., 192., 192.,190., 191., 192., 192., 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f,
196., 197., 198., 198.,199., 200., 201., 201.,202., 203., 204., 204.,202., 203., 204., 204.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.,208., 209., 210., 210.,211., 212., 213., 213.,214., 215., 216., 216.,214., 215., 216., 216.}); 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f});
input.linspace(1.); input.linspace(1.);
nd4j::ops::maxpool3dnew op; nd4j::ops::maxpool3dnew op;
@ -1333,15 +1332,15 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f,
0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f,
0.333333, 0.666667, 0.333333,0.666667, 1.333333, 0.666667,0.666667, 1.333333, 0.666667,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667,0.333333, 0.666667, 0.333333,0.333333, 0.666667, 0.333333,0.166667, 0.333333, 0.166667}); 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f});
input.linspace(1.); input.linspace(1.);
gradO = 2.; gradO = 2.;
@ -1367,15 +1366,15 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f,
1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333,1.333333, 1.333333, 1.333333,2. , 2. , 2. ,2. , 2. , 2. ,1.333333, 1.333333, 1.333333}); 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f});
input.linspace(1.); input.linspace(1.);
gradO = 2.; gradO = 2.;
@ -1403,14 +1402,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 , auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 , 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f,
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 , 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 , 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75, 1.75 , 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
0.41667, 0.41667, 0.41667,0.83333, 0.83333, 0.83333,1.25, 1.25, 1.25 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 , 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f,
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,0.83333, 0.83333, 0.83333,1.66667, 1.66667, 1.66667,2.5 , 2.5 , 2.5 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 , 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f,
1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25, 5.25 ,1.25 , 1.25 , 1.25 ,2.5 , 2.5 , 2.5 ,3.75, 3.75, 3.75 }); 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f});
input.linspace(1.); input.linspace(1.);
gradO = 2.; gradO = 2.;
@ -1435,14 +1434,14 @@ TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 , auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. , 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f,
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 , 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f,
1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 , 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f,
0.16667, 0.16667, 0.16667,0.33333, 0.33333, 0.33333,0.5 , 0.5 , 0.5 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.58333, 0.58333, 0.58333,1.16667, 1.16667, 1.16667,1.75, 1.75 , 1.75 , 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f,
0.91667, 0.91667, 0.91667,1.83333, 1.83333, 1.83333,2.75, 2.75 , 2.75 ,0.33333, 0.33333, 0.33333,0.66667, 0.66667, 0.66667,1. , 1. , 1. ,0.66667, 0.66667, 0.66667,1.33333, 1.33333, 1.33333,2. , 2. , 2. , 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f,
1.16667, 1.16667, 1.16667,2.33333, 2.33333, 2.33333,3.5 , 3.5 , 3.5 ,1.83333, 1.83333, 1.83333,3.66667, 3.66667, 3.66667,5.5 , 5.5 , 5.5 ,0.5 , 0.5 , 0.5 ,1. , 1. , 1. ,1.5 , 1.5 , 1.5 , 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f,
1. , 1. , 1. ,2. , 2. , 2. ,3. , 3. , 3. ,1.75 , 1.75 , 1.75 ,3.5 , 3.5 , 3.5 ,5.25, 5.25 , 5.25 ,2.75 , 2.75 , 2.75 ,5.5 , 5.5 , 5.5 ,8.25, 8.25 , 8.25 }); 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f});
input.linspace(1.); input.linspace(1.);
gradO = 2.; gradO = 2.;
@ -1467,12 +1466,12 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.1, 0.2,0. , 0.3, 0.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.5, 0.6,0. , 0.7, 0.8, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.9, 1. ,0. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.3, 1.4,0. , 1.5, 1.6, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.7, 1.8,0. , 1.9, 2. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.1, 2.2,0. , 2.3, 2.4, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.5, 2.6,0. , 2.7, 2.8,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.9, 3. ,0. , 3.1, 3.2, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.3, 3.4,0. , 3.5, 3.6,0. , 0. , 0. ,0. , 0. , 0. ,0. , 3.7, 3.8,0. , 3.9, 4. , 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.1, 4.2,0. , 4.3, 4.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 4.5, 4.6,0. , 4.7, 4.8}); 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1497,15 +1496,15 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oD, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.000e+00, 0.000e+00, 0.000e+00,1.000e-01, 2.000e-01, 7.000e-01,5.000e-01, 6.000e-01, 1.500e+00,2.200e+00, 2.400e+00, 5.400e+00,0.000e+00, 0.000e+00, 0.000e+00,1.700e+00, 1.800e+00, 3.900e+00,2.100e+00, 2.200e+00, 4.700e+00,5.400e+00, 5.600e+00, 1.180e+01, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f,
0.000e+00, 0.000e+00, 0.000e+00,8.200e+00, 8.400e+00, 1.740e+01,9.000e+00, 9.200e+00, 1.900e+01,2.040e+01, 2.080e+01, 4.280e+01,0.000e+00, 0.000e+00, 0.000e+00,6.500e+00, 6.600e+00, 1.350e+01,6.900e+00, 7.000e+00, 1.430e+01,1.500e+01, 1.520e+01, 3.100e+01, 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f,
0.000e+00, 0.000e+00, 0.000e+00,8.100e+00, 8.200e+00, 1.670e+01,8.500e+00, 8.600e+00, 1.750e+01,1.820e+01, 1.840e+01, 3.740e+01,0.000e+00, 0.000e+00, 0.000e+00,2.100e+01, 2.120e+01, 4.300e+01,2.180e+01, 2.200e+01, 4.460e+01,4.600e+01, 4.640e+01, 9.400e+01, 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f,
0.000e+00, 0.000e+00, 0.000e+00,1.290e+01, 1.300e+01, 2.630e+01,1.330e+01, 1.340e+01, 2.710e+01,2.780e+01, 2.800e+01, 5.660e+01,0.000e+00, 0.000e+00, 0.000e+00,1.450e+01, 1.460e+01, 2.950e+01,1.490e+01, 1.500e+01, 3.030e+01,3.100e+01, 3.120e+01, 6.300e+01, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f,
0.000e+00, 0.000e+00, 0.000e+00,3.380e+01, 3.400e+01, 6.860e+01,3.460e+01, 3.480e+01, 7.020e+01,7.160e+01, 7.200e+01, 1.452e+02,0.000e+00, 0.000e+00, 0.000e+00,1.930e+01, 1.940e+01, 3.910e+01,1.970e+01, 1.980e+01, 3.990e+01,4.060e+01, 4.080e+01, 8.220e+01, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f,
0.000e+00, 0.000e+00, 0.000e+00,2.090e+01, 2.100e+01, 4.230e+01,2.130e+01, 2.140e+01, 4.310e+01,4.380e+01, 4.400e+01, 8.860e+01,0.000e+00, 0.000e+00, 0.000e+00,4.660e+01, 4.680e+01, 9.420e+01,4.740e+01, 4.760e+01, 9.580e+01,9.720e+01, 9.760e+01, 1.964e+02, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f,
0.000e+00, 0.000e+00, 0.000e+00,2.570e+01, 2.580e+01, 5.190e+01,2.610e+01, 2.620e+01, 5.270e+01,5.340e+01, 5.360e+01, 1.078e+02,0.000e+00, 0.000e+00, 0.000e+00,2.730e+01, 2.740e+01, 5.510e+01,2.770e+01, 2.780e+01, 5.590e+01,5.660e+01, 5.680e+01, 1.142e+02, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f,
0.000e+00, 0.000e+00, 0.000e+00,5.940e+01, 5.960e+01, 1.198e+02,6.020e+01, 6.040e+01, 1.214e+02,1.228e+02, 1.232e+02, 2.476e+02,0.000e+00, 0.000e+00, 0.000e+00,3.210e+01, 3.220e+01, 6.470e+01,3.250e+01, 3.260e+01, 6.550e+01,6.620e+01, 6.640e+01, 1.334e+02, 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f,
0.000e+00, 0.000e+00, 0.000e+00,3.370e+01, 3.380e+01, 6.790e+01,3.410e+01, 3.420e+01, 6.870e+01,6.940e+01, 6.960e+01, 1.398e+02,0.000e+00, 0.000e+00, 0.000e+00,7.220e+01, 7.240e+01, 1.454e+02,7.300e+01, 7.320e+01, 1.470e+02,1.484e+02, 1.488e+02, 2.988e+02}); 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1530,14 +1529,14 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, { 0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0.2 , 0.3, 1.1, 1.3 , 1.5, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f,
0., 0., 0., 1. , 1.1, 1.2, 2.9, 3.1 , 3.3, 0. , 0. , 0. , 4.7, 4.9 , 5.1, 11.2, 11.6 , 12. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 11. , 11.2, 11.4, 23.8, 24.2 , 24.6, 0. , 0. , 0. , 12.8, 13. , 13.2, 27.4, 27.8 , 28.2, 0. , 0. , 0. , 31. , 31.4 , 31.8, 65.6, 66.39999, 67.2, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f,
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 10.9, 11. , 11.1, 22.7, 22.9 , 23.1, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f,
0., 0., 0., 11.8, 11.9, 12. , 24.5, 24.7 , 24.9, 0. , 0. , 0. , 26.3, 26.5 , 26.7, 54.4, 54.8 , 55.2, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0., 0., 0., 32.6, 32.8, 33. , 67. , 67.4 , 67.8, 0. , 0. , 0. , 34.4, 34.6 , 34.8, 70.6, 71. , 71.4, 0. , 0. , 0. , 74.2, 74.6 , 75. ,152. , 152.8 ,153.6}); 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1563,13 +1562,13 @@ TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.1, 0.2, 0.3, 1.1, 1.3, 1.5, 0, 0, 0, 5.7, 6, 6.3, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f,
14.1, 14.7, 15.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 11.2, 11.4, 23.8, 24.2, 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f,
24.6, 0, 0, 0, 43.8, 44.4, 45, 93, 94.2, 95.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
10.9, 11, 11.1, 22.7, 22.9, 23.1, 0, 0, 0, 38.1, 38.4, 38.7, 78.9, 79.5, 80.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32.6, 32.8, 33, 67, 67.4, 67.8, 0, 0, 0, 108.6, 109.2, 109.8, 222.6, 223.8, 225,}); 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1652,9 +1651,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.1, 0.2,0. , 0.3, 0.4,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.5, 0.6,0. , 0.7, 0.8, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0.9, 1. ,0. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.3, 1.4,0. , 1.5, 1.6, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 1.7, 1.8,0. , 1.9, 2. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 2.1, 2.2,0. , 2.3, 2.4}); 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1679,9 +1678,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0. , 0. , 0. , 0.1, 0.2, 0.7, 0.5, 0.6, 1.5, 2.2, 2.4, 5.4, 0. , 0. , 0. , 1.7, 1.8, 3.9, 2.1, 2.2, 4.7, 5.4, 5.6, 11.8, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f,
0. , 0. , 0. , 3.3, 3.4, 7.1, 3.7, 3.8, 7.9, 8.6, 8.8, 18.2, 0. , 0. , 0. , 4.9, 5. , 10.3, 5.3, 5.4, 11.1,11.8, 12. , 24.6, 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f,
0. , 0. , 0. , 6.5, 6.6, 13.5, 6.9, 7. , 14.3,15. , 15.2, 31. , 0. , 0. , 0. , 8.1, 8.2, 16.7, 8.5, 8.6, 17.5,18.2, 18.4, 37.4}); 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1706,9 +1705,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.1, 0.2, 0.3, 1.1, 1.3, 1.5, 0. , 0. , 0. , 1. , 1.1, 1.2, 2.9, 3.1, 3.3, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f,
0. , 0. , 0. , 4.7, 4.9, 5.1,11.2,11.6,12. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 3.7, 3.8, 3.9, 8.3, 8.5, 8.7, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f,
0. , 0. , 0. , 4.6, 4.7, 4.8,10.1,10.3,10.5, 0. , 0. , 0. ,11.9,12.1,12.3,25.6,26. ,26.4}); 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1733,9 +1732,9 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0.1, 0.2, 0.3,0.4, 0.5, 0.6, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f,
0. , 0. , 0. ,0.7, 0.8, 0.9,1. , 1.1, 1.2,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0. , 0. , 0. ,1.3, 1.4, 1.5,1.6, 1.7, 1.8,0. , 0. , 0. ,1.9, 2. , 2.1,2.2, 2.3, 2.4}); 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1815,8 +1814,8 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) {
// TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375};
auto input = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW});
auto epsilon = NDArrayFactory::create<TypeParam>('c', {bS,iD,oH,oW}, {3.5 , 4.5 , 5.5, 7.5 , 8.5 , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}); auto epsilon = NDArrayFactory::create<TypeParam>('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW}, {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}); auto expected = NDArrayFactory::create<TypeParam>('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f});
input.linspace(1.); input.linspace(1.);
@ -1842,12 +1841,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.016667,0.05 ,0.033333,0.066667,0.166667,0.1 ,0.066667,0.166667,0.1 ,0.05 ,0.116667,0.066667, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f,
0.083333,0.183333,0.1 ,0.2 ,0.433333,0.233333,0.2 ,0.433333,0.233333,0.116667,0.25 ,0.133333, 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f,
0.15 ,0.316667,0.166667,0.333333,0.7 ,0.366667,0.333333,0.7 ,0.366667,0.183333,0.383333,0.2 , 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f,
0.216667,0.45 ,0.233333,0.466667,0.966667,0.5 ,0.466667,0.966667,0.5 ,0.25 ,0.516667,0.266667, 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f,
0.283333,0.583333,0.3 ,0.6 ,1.233333,0.633333,0.6 ,1.233333,0.633333,0.316667,0.65 ,0.333333, 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f,
0.35 ,0.716667,0.366667,0.733333,1.5 ,0.766667,0.733333,1.5 ,0.766667,0.383333,0.783333,0.4 }); 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 0.4f });
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1873,12 +1872,12 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.233333,0.3 ,0.366667,0.55 ,0.65 ,0.75 ,0.95 ,1.05 ,1.15 ,0.766667,0.833333,0.9 , auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f,
1.3 ,1.366667,1.433333,2.15 ,2.25 ,2.35 ,2.55 ,2.65 ,2.75 ,1.833333,1.9 ,1.966667, 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f,
2.366667,2.433333,2.5 ,3.75 ,3.85 ,3.95 ,4.15 ,4.25 ,4.35 ,2.9 ,2.966667,3.033333, 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f,
3.433333,3.5 ,3.566667,5.35 ,5.45 ,5.55 ,5.75 ,5.85 ,5.95 ,3.966667,4.033333,4.1 , 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f,
4.5 ,4.566667,4.633333,6.95 ,7.05 ,7.15 ,7.35 ,7.45 ,7.55 ,5.033333,5.1 ,5.166667, 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f,
5.566667,5.633333,5.7 ,8.549999,8.65 ,8.75 ,8.95 ,9.05 ,9.150001,6.1 ,6.166667,6.233334}); 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1904,10 +1903,10 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.19167, 0.23333, 0.275, 0.50833, 0.59167, 0.675, 1.2 , 1.325, 1.45 ,0.50833,0.56667, 0.625, 1.19167,1.30833, 1.425, 2.4 ,2.575, 2.75 , auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f,
1.18333, 1.24167, 1.3 , 2.54167, 2.65833, 2.775, 4.425, 4.6 , 4.775,1.01667,1.05833, 1.1 , 2.15833,2.24167, 2.325, 3.675,3.8 , 3.925, 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f,
1.69167, 1.73333, 1.775, 3.50833, 3.59167, 3.675, 5.7 , 5.825, 5.95 ,2.60833,2.66667, 2.725, 5.39167,5.50833, 5.625, 8.7 ,8.875, 9.05 , 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f,
3.28333, 3.34167, 3.4 , 6.74167, 6.85833, 6.975,10.725,10.9 ,11.075,2.51667,2.55833, 2.6 , 5.15833,5.24167, 5.325, 8.175,8.3 , 8.425}); 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1933,10 +1932,10 @@ TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, iC});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.01667,0.03333,0.05,0.08333,0.11667,0.15,0.06667,0.08333,0.1,0.13333,0.16667,0.2 ,0.36667,0.43333,0.5 ,0.23333,0.26667,0.3, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f,
0.13333,0.16667,0.2 ,0.36667,0.43333,0.5 ,0.23333,0.26667,0.3,0.11667,0.13333,0.15,0.28333,0.31667,0.35,0.16667,0.18333,0.2, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f,
0.21667,0.23333,0.25,0.48333,0.51667,0.55,0.26667,0.28333,0.3,0.53333,0.56667,0.6 ,1.16667,1.23333,1.3 ,0.63333,0.66667,0.7, 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f,
0.53333,0.56667,0.6 ,1.16667,1.23333,1.3 ,0.63333,0.66667,0.7,0.31667,0.33333,0.35,0.68333,0.71667,0.75,0.36667,0.38333,0.4}); 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -1995,12 +1994,12 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {9.661570e-04,9.671602e-03,1.306569e-02,3.679184e-02,1.297220e-01,1.040181e-01,1.126750e-01,3.320884e-01,2.340406e-01,1.333333e-01,3.352886e-01,2.070211e-01, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f,
8.991618e-02,2.160601e-01,1.283173e-01,2.744226e-01,6.364498e-01,3.662123e-01,3.869788e-01,8.808994e-01,4.984556e-01,2.613189e-01,5.818475e-01,3.225517e-01, 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f,
2.065654e-01,4.553546e-01,2.501175e-01,5.190718e-01,1.131343e+00,6.148388e-01,6.362602e-01,1.377521e+00,7.439550e-01,3.833026e-01,8.227519e-01,4.407146e-01, 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f,
3.261206e-01,6.969233e-01,3.717564e-01,7.627507e-01,1.620991e+00,8.600952e-01,8.814538e-01,1.866888e+00,9.873542e-01,5.046682e-01,1.064004e+00,5.602558e-01, 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f,
4.464697e-01,9.389536e-01,4.932274e-01,1.005908e+00,2.108550e+00,1.104095e+00,1.125322e+00,2.354009e+00,1.230180e+00,6.258913e-01,1.305581e+00,6.804127e-01, 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f,
5.671396e-01,1.181128e+00,6.145977e-01,1.248783e+00,2.595083e+00,1.347494e+00,1.368600e+00,2.840157e+00,1.472778e+00,7.470673e-01,1.547362e+00,8.008900e-01}); 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);
@ -2028,12 +2027,12 @@ TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}); auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW}); auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, iC, oH, oW});
auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.007931,0.042891,0.040544,0.09369 ,0.276841,0.191675,0.163957,0.442946,0.287512,0.154919,0.373153,0.221172, auto expected = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f,
0.15901 ,0.365232,0.207846,0.428282,0.959455,0.534076,0.508585,1.128771,0.623089,0.319794,0.698063,0.379547, 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f,
0.321068,0.692438,0.372316,0.757521,1.620323,0.864566,0.838684,1.787943,0.951023,0.483194,1.023434,0.541058, 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f,
0.483937,1.019414,0.536145,1.085348,2.276996,1.192917,1.166749,2.443606,1.278126,0.646499,1.349361,0.703463, 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f,
0.647021,1.346249,0.699745,1.412654,2.932174,1.520512,1.494153,3.098146,1.604985,0.809791,1.675544,0.866229, 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f,
0.810192,1.673009,0.863237,1.739711,3.58665 ,1.847753,1.82126 ,3.752188,1.931741,0.973081,2.001861,1.029173}); 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f});
input.linspace(1.); input.linspace(1.);
gradO.linspace(0.1, 0.1); gradO.linspace(0.1, 0.1);

View File

@ -2857,8 +2857,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
20.07843f, 21.019608f, 21.960783f, 23.058823f, 20.07843f, 21.019608f, 21.960783f, 23.058823f,
20.07843f, 21.019608f, 21.960783f, 23.058823f 20.07843f, 21.019608f, 21.960783f, 23.058823f
}); });
NDArray min = NDArrayFactory::create<float>({-20., -19., -18., -17}); NDArray min = NDArrayFactory::create<float>({-20.f, -19.f, -18.f, -17.f});
NDArray max = NDArrayFactory::create<float>({20., 21., 22., 23}); NDArray max = NDArrayFactory::create<float>({20.f, 21.f, 22.f, 23.f});
x.linspace(-60.); x.linspace(-60.);
nd4j::ops::fake_quant_with_min_max_vars_per_channel op; nd4j::ops::fake_quant_with_min_max_vars_per_channel op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.execute({&x, &min, &max}, {}, {});
@ -3033,8 +3033,8 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) {
auto gamma = NDArrayFactory::create<TypeParam>('c', {4}); auto gamma = NDArrayFactory::create<TypeParam>('c', {4});
auto beta = NDArrayFactory::create<TypeParam>('c', {4}); auto beta = NDArrayFactory::create<TypeParam>('c', {4});
auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-0.52733537,-0.35763144,-0.18792751,-0.01822358, 0.15148035, 0.32118428, 0.49088821, 0.66059214, 0.83029607, 1. , 1.16970393, 1.33940786, auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f,
1.50911179, 1.67881572, 1.84851965, 2.01822358, 2.18792751, 2.35763144, 2.52733537, 2.6970393 , 2.86674323, 3.03644717, 3.2061511 , 3.37585503}); 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f});
input.linspace(0.1, 0.1); input.linspace(0.1, 0.1);
mean.assign(1.); mean.assign(1.);
@ -3061,13 +3061,13 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test2) {
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) { TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) {
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4}); auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
auto mean = NDArrayFactory::create<TypeParam>('c', {3}, {1.05, 1.1, 1.15}); auto mean = NDArrayFactory::create<TypeParam>('c', {3}, {1.05f, 1.1f, 1.15f});
auto variance = NDArrayFactory::create<TypeParam>('c', {3}, {0.5, 0.6, 0.7}); auto variance = NDArrayFactory::create<TypeParam>('c', {3}, {0.5f, 0.6f, 0.7f});
auto gamma = NDArrayFactory::create<TypeParam>('c', {3}, {1.2, 1.3, 1.4}); auto gamma = NDArrayFactory::create<TypeParam>('c', {3}, {1.2f, 1.3f, 1.4f});
auto beta = NDArrayFactory::create<TypeParam>('c', {3}, {0.1, 0.2, 0.3}); auto beta = NDArrayFactory::create<TypeParam>('c', {3}, {0.1f, 0.2f, 0.3f});
auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-1.51218734,-1.34248341,-1.17277948,-1.00307555,-0.80696728,-0.6391394 ,-0.47131152,-0.30348364,-0.11832703, 0.04900378, 0.21633459, 0.38366541, auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f,
0.52425983, 0.69396376, 0.86366769, 1.03337162, 1.20696728, 1.37479516, 1.54262304, 1.71045092, 1.8896427 , 2.05697351, 2.22430432, 2.39163513,}); 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f});
input.linspace(0.1, 0.1); input.linspace(0.1, 0.1);
@ -3089,13 +3089,13 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test3) {
TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4}); auto input = NDArrayFactory::create<TypeParam>('c', {2,3,4});
auto mean = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}); auto mean = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f});
auto variance = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); auto variance = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f});
auto gamma = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}); auto gamma = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f});
auto beta = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.66, 0.7, 0.8}); auto beta = NDArrayFactory::create<TypeParam>('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f});
auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-1.51218734,-1.31045092,-1.12231189,-0.9416324 ,-0.83337162,-0.6391394 ,-0.45298865,-0.2708162 ,-0.1545559 , 0.03217212, 0.21633459, 0.4, auto expected = NDArrayFactory::create<TypeParam>('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f,
0.58432694, 0.82999915, 0.95743373, 1.14688951, 1.25894242, 1.50999575, 1.64392367, 1.84066852, 1.93355791, 2.18999235, 2.33041362, 2.53444754}); 0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f});
input.linspace(0.1, 0.1); input.linspace(0.1, 0.1);

View File

@ -1406,9 +1406,9 @@ TEST_F(DeclarableOpsTests12, pad_tests1) {
// REFLECT mode 2D // REFLECT mode 2D
TEST_F(DeclarableOpsTests12, pad_tests2) { TEST_F(DeclarableOpsTests12, pad_tests2) {
float inBuff[] = {1,2,3,4,5,6}; float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
int padBuff[] = {1,1,2,2}; int padBuff[] = {1,1,2,2};
float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3}); auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
@ -1433,9 +1433,9 @@ TEST_F(DeclarableOpsTests12, pad_tests2) {
// SYMMETRIC mode 2D // SYMMETRIC mode 2D
TEST_F(DeclarableOpsTests12, pad_tests3) { TEST_F(DeclarableOpsTests12, pad_tests3) {
float inBuff[] = {1,2,3,4,5,6}; float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
int padBuff[] = {1,1,2,2}; int padBuff[] = {1,1,2,2};
float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3}); auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
@ -1460,13 +1460,13 @@ TEST_F(DeclarableOpsTests12, pad_tests3) {
// CONSTANT mode 3D // CONSTANT mode 3D
TEST_F(DeclarableOpsTests12, pad_tests4) { TEST_F(DeclarableOpsTests12, pad_tests4) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f,
7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f,
0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
@ -1499,12 +1499,12 @@ TEST_F(DeclarableOpsTests12, pad_tests4) {
// REFLECT mode 3D // REFLECT mode 3D
TEST_F(DeclarableOpsTests12, pad_tests5) { TEST_F(DeclarableOpsTests12, pad_tests5) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.execute({&input, &paddings}, {}, {1});
@ -1525,13 +1525,13 @@ TEST_F(DeclarableOpsTests12, pad_tests5) {
// SYMMETRIC mode 3D // SYMMETRIC mode 3D
TEST_F(DeclarableOpsTests12, pad_tests6) { TEST_F(DeclarableOpsTests12, pad_tests6) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.execute({&input, &paddings}, {}, {2});
@ -1552,12 +1552,12 @@ TEST_F(DeclarableOpsTests12, pad_tests6) {
TEST_F(DeclarableOpsTests12, pad_tests7) TEST_F(DeclarableOpsTests12, pad_tests7)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.execute({&input, &paddings}, {}, {0});
@ -1578,12 +1578,12 @@ TEST_F(DeclarableOpsTests12, pad_tests7)
TEST_F(DeclarableOpsTests12, pad_tests8) TEST_F(DeclarableOpsTests12, pad_tests8)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.execute({&input, &paddings}, {}, {1});
@ -1604,12 +1604,12 @@ TEST_F(DeclarableOpsTests12, pad_tests8)
TEST_F(DeclarableOpsTests12, pad_tests9) TEST_F(DeclarableOpsTests12, pad_tests9)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.execute({&input, &paddings}, {}, {2});
@ -2151,13 +2151,13 @@ TEST_F(DeclarableOpsTests12, pad_tests34) {
// CONSTANT mode 2D // CONSTANT mode 2D
TEST_F(DeclarableOpsTests12, Pad_1) { TEST_F(DeclarableOpsTests12, Pad_1) {
float inBuff[] = {1,2,3,4,5,6}; double inBuff[] = {1,2,3,4,5,6};
int padBuff[] = {1,1,2,2}; int padBuff[] = {1,1,2,2};
float expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; double expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.execute({&input, &paddings}, {}, {0});
@ -2178,13 +2178,13 @@ TEST_F(DeclarableOpsTests12, Pad_1) {
// REFLECT mode 2D // REFLECT mode 2D
TEST_F(DeclarableOpsTests12, Pad_2) { TEST_F(DeclarableOpsTests12, Pad_2) {
float inBuff[] = {1,2,3,4,5,6}; double inBuff[] = {1,2,3,4,5,6};
int padBuff[] = {1,1,2,2}; int padBuff[] = {1,1,2,2};
float expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; double expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.execute({&input, &paddings}, {}, {1});
@ -2205,13 +2205,13 @@ TEST_F(DeclarableOpsTests12, Pad_2) {
// SYMMETRIC mode 2D // SYMMETRIC mode 2D
TEST_F(DeclarableOpsTests12, Pad_3) { TEST_F(DeclarableOpsTests12, Pad_3) {
float inBuff[] = {1,2,3,4,5,6}; double inBuff[] = {1,2,3,4,5,6};
int padBuff[] = {1,1,2,2}; int padBuff[] = {1,1,2,2};
float expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; double expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.execute({&input, &paddings}, {}, {2});
@ -2232,13 +2232,13 @@ TEST_F(DeclarableOpsTests12, Pad_3) {
// CONSTANT mode 3D // CONSTANT mode 3D
TEST_F(DeclarableOpsTests12, Pad_4) { TEST_F(DeclarableOpsTests12, Pad_4) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; double expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.execute({&input, &paddings}, {}, {0});
@ -2260,12 +2260,12 @@ TEST_F(DeclarableOpsTests12, Pad_4) {
// REFLECT mode 3D // REFLECT mode 3D
TEST_F(DeclarableOpsTests12, Pad_5) { TEST_F(DeclarableOpsTests12, Pad_5) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.execute({&input, &paddings}, {}, {1});
@ -2286,13 +2286,13 @@ TEST_F(DeclarableOpsTests12, Pad_5) {
// SYMMETRIC mode 3D // SYMMETRIC mode 3D
TEST_F(DeclarableOpsTests12, Pad_6) { TEST_F(DeclarableOpsTests12, Pad_6) {
float inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18};
int padBuff[] = {1,1,2,2,2,2}; int padBuff[] = {1,1,2,2,2,2};
float expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3,3}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2,3,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {3,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7,7}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4,7,7});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.execute({&input, &paddings}, {}, {2});
@ -2313,12 +2313,12 @@ TEST_F(DeclarableOpsTests12, Pad_6) {
TEST_F(DeclarableOpsTests12, Pad_7) TEST_F(DeclarableOpsTests12, Pad_7)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {0}); auto results = op.execute({&input, &paddings}, {}, {0});
@ -2339,12 +2339,12 @@ TEST_F(DeclarableOpsTests12, Pad_7)
TEST_F(DeclarableOpsTests12, Pad_8) TEST_F(DeclarableOpsTests12, Pad_8)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {1}); auto results = op.execute({&input, &paddings}, {}, {1});
@ -2365,12 +2365,12 @@ TEST_F(DeclarableOpsTests12, Pad_8)
TEST_F(DeclarableOpsTests12, Pad_9) TEST_F(DeclarableOpsTests12, Pad_9)
{ {
float inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1};
float expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2, 2, 2, 2}); auto input = NDArrayFactory::create<double>(inBuff, 'c', {2, 2, 2, 2});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2}); auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {4, 2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4, 4, 4, 4}); auto expected = NDArrayFactory::create<double>(expBuff, 'c', {4, 4, 4, 4});
nd4j::ops::pad op; nd4j::ops::pad op;
auto results = op.execute({&input, &paddings}, {}, {2}); auto results = op.execute({&input, &paddings}, {}, {2});
@ -2387,8 +2387,8 @@ TEST_F(DeclarableOpsTests12, Pad_9)
} }
TEST_F(DeclarableOpsTests12, Test_Expose_1) { TEST_F(DeclarableOpsTests12, Test_Expose_1) {
auto input0 = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 6, 5, 4}); auto input0 = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 6, 5, 4});
auto input1 = NDArrayFactory::create<float>('c', {2, 3}, {3, 2, 1, 4, 5, 6}); auto input1 = NDArrayFactory::create<double>('c', {2, 3}, {3, 2, 1, 4, 5, 6});
nd4j::ops::expose op; nd4j::ops::expose op;

View File

@ -1027,13 +1027,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_1) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f,
0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f,
0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f,
0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f,
0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f});
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861}); auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f});
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@ -1097,11 +1097,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_2) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999, auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f,
0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764, 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f,
0.61725 , 0.61725 , 0.61725 , 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735}); 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, 0.577735f, 0.577735f, 0.577735f});
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922}); auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f});
nd4j::ops::lstmLayer op; nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

View File

@ -178,10 +178,10 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3}); auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, { auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast op; nd4j::ops::adjust_contrast op;
@ -196,10 +196,10 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) {
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) {
auto x = NDArrayFactory::create<float>('c', {1, 4,4,3}); auto x = NDArrayFactory::create<float>('c', {1, 4,4,3});
auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, { auto e = NDArrayFactory::create<float>('c', {1, 4,4,3}, {
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f
}); });
x.linspace(1.); x.linspace(1.);
nd4j::ops::adjust_contrast_v2 op; nd4j::ops::adjust_contrast_v2 op;
@ -243,8 +243,8 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_1) {
TEST_F(DeclarableOpsTests15, Test_BitCast_2) { TEST_F(DeclarableOpsTests15, Test_BitCast_2) {
auto x = NDArrayFactory::create<float>('c', {2, 4}); auto x = NDArrayFactory::create<float>('c', {2, 4});
auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25, auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f,
0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f});
x.linspace(1.); x.linspace(1.);
nd4j::ops::bitcast op; nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
@ -423,9 +423,9 @@ TEST_F(DeclarableOpsTests15, test_check_numeric_3) {
} }
TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.}); auto x = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.}); auto g = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.}); auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
nd4j::ops::layer_norm op; nd4j::ops::layer_norm op;
auto result = op.execute({&x, &g, &b}, {}, {0}, {false}); auto result = op.execute({&x, &g, &b}, {}, {0}, {false});
@ -434,10 +434,10 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_1) {
} }
TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
auto x = NDArrayFactory::create<float>('c', {1, 5}, {1., 2., 3., 4., 5.}); auto x = NDArrayFactory::create<float>('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto g = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.}); auto g = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto b = NDArrayFactory::create<float>('c', {5}, {1., 2., 3., 4., 5.}); auto b = NDArrayFactory::create<float>('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f});
auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0., 0., 0., 0., 0.}); auto eps = NDArrayFactory::create<float>('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::layer_norm_bp op; nd4j::ops::layer_norm_bp op;
auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false}); auto result = op.execute({&x, &g, &b, &eps}, {}, {0}, {false});

View File

@ -40,7 +40,7 @@ public:
}; };
TEST_F(DeclarableOpsTests16, scatter_upd_1) { TEST_F(DeclarableOpsTests16, scatter_upd_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1}); auto x = NDArrayFactory::create<float>('c', {3}, {1.f, 1.f, 1.f});
auto y = NDArrayFactory::create<int>(0); auto y = NDArrayFactory::create<int>(0);
auto w = NDArrayFactory::create<float>(3.0f); auto w = NDArrayFactory::create<float>(3.0f);
auto e = NDArrayFactory::create<float>('c', {3}, {3.f, 1.f, 1.f}); auto e = NDArrayFactory::create<float>('c', {3}, {3.f, 1.f, 1.f});

View File

@ -400,7 +400,7 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) { TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
auto A = NDArrayFactory::create<float>('c', {3, 3}); auto A = NDArrayFactory::create<float>('c', {3, 3});
auto B = NDArrayFactory::create<float>('c', {3, 1}); auto B = NDArrayFactory::create<float>('c', {3, 1});
auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00, 32.00, 50.00}); auto exp = NDArrayFactory::create<float>('c', {3, 1}, {14.00f, 32.00f, 50.00f});
A.linspace(1); A.linspace(1);
B.linspace(1); B.linspace(1);
@ -457,9 +457,9 @@ TEST_F(DeclarableOpsTests2, Test_Squeeze_2) {
} }
TEST_F(DeclarableOpsTests2, Test_FloorMod_1) { TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {2.0, 6.0, -3.0}); auto x = NDArrayFactory::create<float>('c', {1, 3}, {2.0f, 6.0f, -3.0f});
auto y = NDArrayFactory::create<float>('c', {1, 3}, {-3.0, 2.0, -2.0}); auto y = NDArrayFactory::create<float>('c', {1, 3}, {-3.0f, 2.0f, -2.0f});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-1., 0., -1.,}); auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-1.f, 0.f, -1.f});
nd4j::ops::floormod op; nd4j::ops::floormod op;
@ -475,9 +475,9 @@ TEST_F(DeclarableOpsTests2, Test_FloorMod_1) {
} }
TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) { TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0}); auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0f, 6.0f, -3.0f});
auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0}); auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0f, 2.0f, -2.0f});
auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-2., 3., 1.,}); auto exp = NDArrayFactory::create<float>('c', {1, 3}, {-2.f, 3.f, 1.f});
nd4j::ops::floordiv op; nd4j::ops::floordiv op;
@ -494,9 +494,9 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) {
} }
TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) { TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0, 6.0, -3.0}); auto x = NDArrayFactory::create<float>('c', {1, 3}, {3.0f, 6.0f, -3.0f});
auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0, 2.0, -2.0}); auto y = NDArrayFactory::create<float>('c', {1, 3}, {-2.0f, 2.0f, -2.0f});
auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3}); auto eps = NDArrayFactory::create<float>('c', {1, 3}, {1.f, 2.f, 3.f});
auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f}); auto exp1 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f}); auto exp2 = NDArrayFactory::create<float>('c', {1, 3}, {0.f, 0.f, 0.f});
@ -518,8 +518,8 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) {
} }
TEST_F(DeclarableOpsTests2, Test_CRelu_1) { TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0, 2.0, 3.0, 4.0}); auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
auto exp = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 0, 0, 3.0, 4.0, 0, 0}); auto exp = NDArrayFactory::create<float>('c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f});
nd4j::ops::crelu op; nd4j::ops::crelu op;
@ -536,9 +536,9 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_1) {
} }
TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) { TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0, 2.0, -3.0, 4.0}); auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f});
auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 4, 3, 3.0, 4.0, 2, 1}); auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, -2, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, -2.f, 4.f});
nd4j::ops::crelu_bp op; nd4j::ops::crelu_bp op;
auto result = op.execute({&x, &eps}, {}, {}); auto result = op.execute({&x, &eps}, {}, {});
@ -556,9 +556,9 @@ TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) {
TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) { TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2}); auto x = NDArrayFactory::create<float>('c', {2, 2});
auto y = NDArrayFactory::create<float>('c', {2, 2}); auto y = NDArrayFactory::create<float>('c', {2, 2});
auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0, 2.0, 0, 1, 3.0, 4.0, 0, 1}); auto eps = NDArrayFactory::create<float>('c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f});
auto expEX = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto expEX = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
auto expEY = NDArrayFactory::create<float>('c', {2, 2}, {0, 1, 0, 1}); auto expEY = NDArrayFactory::create<float>('c', {2, 2}, {0.f, 1.f, 0.f, 1.f});
nd4j::ops::concat_bp op; nd4j::ops::concat_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {-1}); auto result = op.execute({&x, &y, &eps}, {}, {-1});
@ -581,9 +581,9 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot5) { TEST_F(DeclarableOpsTests2, TestTensorDot5) {
auto x = NDArrayFactory::create<float>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518}); auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1,2}); auto results = op.execute({&x, &y}, {}, {1,1,1,2});
@ -603,9 +603,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot5) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot6) { TEST_F(DeclarableOpsTests2, TestTensorDot6) {
auto x = NDArrayFactory::create<float>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592}); auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1,2}); auto results = op.execute({&x, &y}, {}, {1,1,1,2});
@ -624,9 +624,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot6) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot7) { TEST_F(DeclarableOpsTests2, TestTensorDot7) {
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478}); auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1,2}); auto results = op.execute({&x, &y}, {}, {1,1,1,2});
@ -645,9 +645,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot7) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot8) { TEST_F(DeclarableOpsTests2, TestTensorDot8) {
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528}); auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1,2}); auto results = op.execute({&x, &y}, {}, {1,1,1,2});
@ -674,9 +674,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot9) {
// z.printShapeInfo(); // z.printShapeInfo();
// z.printIndexedBuffer(); // z.printIndexedBuffer();
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422}); auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {1,0,1,0}); auto results = op.execute({&x, &y}, {}, {1,0,1,0});
@ -695,9 +695,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot9) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot10) { TEST_F(DeclarableOpsTests2, TestTensorDot10) {
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906}); auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
@ -717,9 +717,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot10) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot11) { TEST_F(DeclarableOpsTests2, TestTensorDot11) {
auto x = NDArrayFactory::create<float>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998}); auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
@ -738,9 +738,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot11) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot12) { TEST_F(DeclarableOpsTests2, TestTensorDot12) {
auto x = NDArrayFactory::create<float>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692}); auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2}); auto results = op.execute({&x, &y}, {}, {2,0,1, 2,0,2});
@ -759,9 +759,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot12) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot13) { TEST_F(DeclarableOpsTests2, TestTensorDot13) {
auto x = NDArrayFactory::create<float>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640}); auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
@ -780,9 +780,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot13) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot14) { TEST_F(DeclarableOpsTests2, TestTensorDot14) {
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648}); auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
@ -801,9 +801,9 @@ TEST_F(DeclarableOpsTests2, TestTensorDot14) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot15) { TEST_F(DeclarableOpsTests2, TestTensorDot15) {
auto x = NDArrayFactory::create<float>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15}); auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<float>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16}); auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<float>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624}); auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
nd4j::ops::tensormmul op; nd4j::ops::tensormmul op;
auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0}); auto results = op.execute({&x, &y}, {}, {2,0,2, 2,1,0});
@ -1449,7 +1449,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4}); auto labels = NDArrayFactory::create<float>('c', {2,3,4});
auto predictions = NDArrayFactory::create<float>('c', {2,3,4}); auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,3,4}); auto weights = NDArrayFactory::create<float>('c', {1,3,4});
auto expected = NDArrayFactory::create<float>('c', {1,3,4}, {-91.5,-107.5,-125.5,-145.5, -167.5,-191.5,-217.5,-245.5, -275.5,-307.5,-341.5,-377.5}); auto expected = NDArrayFactory::create<float>('c', {1,3,4}, {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, -275.5f, -307.5f, -341.5f, -377.5f});
labels.linspace(1); labels.linspace(1);
predictions.linspace(2); predictions.linspace(2);
@ -1475,7 +1475,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4}); auto labels = NDArrayFactory::create<float>('c', {2,3,4});
auto predictions = NDArrayFactory::create<float>('c', {2,3,4}); auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,1,4}); auto weights = NDArrayFactory::create<float>('c', {2,1,4});
auto expected = NDArrayFactory::create<float>('c', {2,1,4}, {-3.25, -4., -4.75, -5.5,-12.25,-13.,-13.75,-14.5}); auto expected = NDArrayFactory::create<float>('c', {2,1,4}, {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f});
labels.linspace(1); labels.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1502,7 +1502,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4}); auto labels = NDArrayFactory::create<float>('c', {2,3,4});
auto predictions = NDArrayFactory::create<float>('c', {2,3,4}); auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,1}); auto weights = NDArrayFactory::create<float>('c', {2,3,1});
auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2., -6.,-10.,-14.,-18.,-22.}); auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f});
labels.linspace(1); labels.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1527,7 +1527,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4}); auto labels = NDArrayFactory::create<float>('c', {2,3,4});
auto predictions = NDArrayFactory::create<float>('c', {2,3,4}); auto predictions = NDArrayFactory::create<float>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1}); auto weights = NDArrayFactory::create<float>('c', {1,1});
auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2., -6.,-10.,-14.,-18.,-22.}); auto expected = NDArrayFactory::create<float>('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f});
labels.linspace(1); labels.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1702,10 +1702,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test1) { TEST_F(DeclarableOpsTests2, hinge_loss_test1) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,4}); auto weights = NDArrayFactory::create<double>('c', {2,3,4});
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1727,10 +1727,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test1) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test2) { TEST_F(DeclarableOpsTests2, hinge_loss_test2) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1}); auto weights = NDArrayFactory::create<double>('c', {1,1});
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1752,10 +1752,10 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test2) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test3) { TEST_F(DeclarableOpsTests2, hinge_loss_test3) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,3,1}); auto weights = NDArrayFactory::create<double>('c', {1,3,1});
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); auto expected = NDArrayFactory::create<double>('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1777,9 +1777,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test3) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test4) { TEST_F(DeclarableOpsTests2, hinge_loss_test4) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,4}); auto weights = NDArrayFactory::create<double>('c', {2,3,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1801,9 +1801,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test4) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test5) { TEST_F(DeclarableOpsTests2, hinge_loss_test5) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1}); auto weights = NDArrayFactory::create<double>('c', {1,1});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1825,9 +1825,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test5) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test6) { TEST_F(DeclarableOpsTests2, hinge_loss_test6) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,1,1}); auto weights = NDArrayFactory::create<double>('c', {2,1,1});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1849,9 +1849,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test6) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test7) { TEST_F(DeclarableOpsTests2, hinge_loss_test7) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,4}); auto weights = NDArrayFactory::create<double>('c', {2,3,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1873,9 +1873,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test7) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test8) { TEST_F(DeclarableOpsTests2, hinge_loss_test8) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1}); auto weights = NDArrayFactory::create<double>('c', {1,1});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1897,9 +1897,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test8) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test9) { TEST_F(DeclarableOpsTests2, hinge_loss_test9) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1,4}); auto weights = NDArrayFactory::create<double>('c', {1,1,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1921,9 +1921,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test9) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test10) { TEST_F(DeclarableOpsTests2, hinge_loss_test10) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,4}); auto weights = NDArrayFactory::create<double>('c', {2,3,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1945,9 +1945,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test10) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test11) { TEST_F(DeclarableOpsTests2, hinge_loss_test11) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,1,4}); auto weights = NDArrayFactory::create<double>('c', {2,1,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1969,9 +1969,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test11) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test12) { TEST_F(DeclarableOpsTests2, hinge_loss_test12) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {2,3,4}); auto weights = NDArrayFactory::create<double>('c', {2,3,4});
logits.linspace(1); logits.linspace(1);
weights.assign(0.5); weights.assign(0.5);
@ -1997,9 +1997,9 @@ TEST_F(DeclarableOpsTests2, hinge_loss_test12) {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, hinge_loss_test13) { TEST_F(DeclarableOpsTests2, hinge_loss_test13) {
auto labels = NDArrayFactory::create<float>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); auto labels = NDArrayFactory::create<double>('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0});
auto logits = NDArrayFactory::create<float>('c', {2,3,4}); auto logits = NDArrayFactory::create<double>('c', {2,3,4});
auto weights = NDArrayFactory::create<float>('c', {1,1}); auto weights = NDArrayFactory::create<double>('c', {1,1});
logits.linspace(1); logits.linspace(1);
weights.assign(0.); weights.assign(0.);

File diff suppressed because it is too large Load Diff

View File

@ -201,7 +201,7 @@ TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_10) {
TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) { TYPED_TEST(TypedDeclarableOpsTests4, Test_Pooling_Parity_11) {
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 3, 3}); auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 3, 3});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 2}, {3, 4, 6, 7}); auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f});
x.linspace(1); x.linspace(1);
@ -1582,17 +1582,17 @@ TEST_F(DeclarableOpsTests4, relu6_bp_test1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5, auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f,
8.6, 0., 0., 0.4, 8.6f, 0.f, 0.f, 0.4f,
1.5, 1., 1.3, 1.5, 1.5f, 1.f, 1.3f, 1.5f,
2.6, 2., 3., 1.4} 2.6f, 2.f, 3.f, 1.4f}
); );
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
0.98386997, 0., 0.05358852, 0.9824562, 0.98386997f, 0.f, 0.05358852f, 0.9824562f,
0.99330735, 0., 0., 0.37139067, 0.99330735f, 0.f, 0.f, 0.37139067f,
0.72760683, 0.4850712, 0.5848977, 0.67488194, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f,
0.7581754, 0.58321184, 0.86747235, 0.4048204} 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}
); );
nd4j::ops::lrn op; nd4j::ops::lrn op;
@ -1612,16 +1612,16 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5, auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f,
8.6, 0., 0., 0.4, 8.6f, 0.f, 0.f, 0.4f,
1.5, 1., 1.3, 1.5, 1.5f, 1.f, 1.3f, 1.5f,
2.6, 2., 3., 1.4}); 2.6f, 2.f, 3.f, 1.4f});
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
0.98386997, 0., 0.05358852, 0.9824562, 0.98386997f, 0.f, 0.05358852f, 0.9824562f,
0.99330735, 0., 0., 0.37139067, 0.99330735f, 0.f, 0.f, 0.37139067f,
0.72760683, 0.4850712, 0.5848977, 0.67488194, 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f,
0.7581754, 0.58321184, 0.86747235, 0.4048204}); 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f});
nd4j::ops::lrn op; nd4j::ops::lrn op;
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE); auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {2}, {}, false, nd4j::DataType::DOUBLE);
@ -1641,25 +1641,25 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
5.5, 0., 0.3, 5.5, 5.5f, 0.f, 0.3f, 5.5f,
1.5, 0., 1.3, 6.5, 1.5f, 0.f, 1.3f, 6.5f,
8.6, 0., 0., 0.4, 8.6f, 0.f, 0.f, 0.4f,
2.5, 1., 0.3, 4.5, 2.5f, 1.f, 0.3f, 4.5f,
1.5, 1., 1.3, 1.5, 1.5f, 1.f, 1.3f, 1.5f,
3.5, 0., 1.3, 2.5, 3.5f, 0.f, 1.3f, 2.5f,
2.6, 2., 3., 1.4, 2.6f, 2.f, 3.f, 1.4f,
4.5, 1., 0.3, 0.5} 4.5f, 1.f, 0.3f, 0.5f}
); );
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
0.9824562, 0., 0.03822664, 0.9824562, 0.9824562f, 0.f, 0.03822664f, 0.9824562f,
0.67488194, 0., 0.18924236, 0.96960944, 0.67488194f, 0.f, 0.18924236f, 0.96960944f,
0.99330735, 0., 0., 0.37139067, 0.99330735f, 0.f, 0.f, 0.37139067f,
0.86567914, 0.18702209, 0.05610663, 0.9520745, 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f,
0.6154575, 0.34942827, 0.45425674, 0.6154575, 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f,
0.905509, 0. , 0.2824086, 0.8361251, 0.905509f, 0.f, 0.2824086f, 0.8361251f,
0.57063663, 0.41959068, 0.629386, 0.3504383, 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f,
0.9520745, 0.21039814, 0.06311944, 0.3268602 } 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f }
); );
nd4j::ops::lrn op; nd4j::ops::lrn op;
@ -1680,25 +1680,25 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
5.5, 0., 0.3, 5.5, 5.5f, 0.f, 0.3f, 5.5f,
1.5, 0., 1.3, 6.5, 1.5f, 0.f, 1.3f, 6.5f,
8.6, 0., 0., 0.4, 8.6f, 0.f, 0.f, 0.4f,
2.5, 1., 0.3, 4.5, 2.5f, 1.f, 0.3f, 4.5f,
1.5, 1., 1.3, 1.5, 1.5f, 1.f, 1.3f, 1.5f,
3.5, 0., 1.3, 2.5, 3.5f, 0.f, 1.3f, 2.5f,
2.6, 2., 3., 1.4, 2.6f, 2.f, 3.f, 1.4f,
4.5, 1., 0.3, 0.5} 4.5f, 1.f, 0.3f, 0.5f}
); );
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
0.70082176, 0., 0.03822664, 0.70082176, 0.70082176f, 0.f, 0.03822664f, 0.70082176f,
0.21835658, 0., 0.18924236, 0.9462118, 0.21835658f, 0.f, 0.18924236f, 0.9462118f,
0.9922489, 0., 0., 0.04615111, 0.9922489f, 0.f, 0.f, 0.04615111f,
0.46755522, 0.18702209, 0.05610663, 0.8415994, 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f,
0.5241424, 0.34942827, 0.45425674, 0.5241424, 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f,
0.76033086, 0., 0.2824086, 0.54309344, 0.76033086f, 0.f, 0.2824086f, 0.54309344f,
0.54546785, 0.41959068, 0.629386, 0.29371348, 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f,
0.94679165, 0.21039814, 0.06311944, 0.10519907} 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f}
); );
nd4j::ops::lrn op; nd4j::ops::lrn op;
@ -1719,29 +1719,29 @@ TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
5.5,0., 0.3, 5.5, 5.5f, 0.f, 0.3f, 5.5f,
1.5,0., 1.3, 6.5, 1.5f, 0.f, 1.3f, 6.5f,
8.6,0., 0., 0.4, 8.6f, 0.f, 0.f, 0.4f,
2.5,1., 0.3, 4.5, 2.5f, 1.f, 0.3f, 4.5f,
1.5,1., 1.3, 1.5, 1.5f, 1.f, 1.3f, 1.5f,
3.5,0., 1.3, 2.5, 3.5f, 0.f, 1.3f, 2.5f,
2.6,2., 3., 1.4, 2.6f, 2.f, 3.f, 1.4f,
4.5,1., 0.3, 0.5} 4.5f, 1.f, 0.3f, 0.5f}
); );
auto eps = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, { auto eps = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}, {
0.70082176, 0., 0.03822664, 0.70082176, 0.70082176f, 0.f, 0.03822664f, 0.70082176f,
0.21835658, 0., 0.18924236, 0.9462118, 0.21835658f, 0.f, 0.18924236f, 0.9462118f,
0.9922489, 0., 0. , 0.04615111, 0.9922489f, 0.f, 0.f, 0.04615111f,
0.46755522, 0.18702209, 0.05610663, 0.8415994, 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f,
0.5241424, 0.34942827, 0.45425674, 0.5241424, 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f,
0.76033086, 0., 0.2824086 , 0.54309344, 0.76033086f, 0.f, 0.2824086f, 0.54309344f,
0.54546785, 0.41959068, 0.629386 , 0.29371348, 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f,
0.94679165, 0.21039814, 0.06311944, 0.10519907} 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f}
); );
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4}); auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 4});
@ -1766,7 +1766,7 @@ TEST_F(DeclarableOpsTests4, tri_test1) {
const int rows = 3; const int rows = 3;
const int cols = 5; const int cols = 5;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols}); auto results = op.execute({}, {}, {rows, cols});
@ -1789,7 +1789,7 @@ TEST_F(DeclarableOpsTests4, tri_test2) {
const int cols = 5; const int cols = 5;
const int diag = 2; const int diag = 2;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols, diag}); auto results = op.execute({}, {}, {rows, cols, diag});
@ -1810,7 +1810,7 @@ TEST_F(DeclarableOpsTests4, tri_test3) {
const int cols = 5; const int cols = 5;
const int diag = -1; const int diag = -1;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols, diag}); auto results = op.execute({}, {}, {rows, cols, diag});
@ -1831,7 +1831,7 @@ TEST_F(DeclarableOpsTests4, tri_test4) {
const int cols = 5; const int cols = 5;
const int diag = -2; const int diag = -2;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols, diag}); auto results = op.execute({}, {}, {rows, cols, diag});
@ -1850,7 +1850,7 @@ TEST_F(DeclarableOpsTests4, tri_test5) {
const int rows = 5; const int rows = 5;
auto expected = NDArrayFactory::create<float>('c', {rows, rows}, {1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1}); auto expected = NDArrayFactory::create<float>('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows}); auto results = op.execute({}, {}, {rows});
@ -1871,7 +1871,7 @@ TEST_F(DeclarableOpsTests4, tri_test6) {
const int cols = 5; const int cols = 5;
const int diag = -20; const int diag = -20;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols, diag}); auto results = op.execute({}, {}, {rows, cols, diag});
@ -1892,7 +1892,7 @@ TEST_F(DeclarableOpsTests4, tri_test7) {
const int cols = 5; const int cols = 5;
const int diag = 20; const int diag = 20;
auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); auto expected = NDArrayFactory::create<float>('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f});
nd4j::ops::tri op; nd4j::ops::tri op;
auto results = op.execute({}, {}, {rows, cols, diag}); auto results = op.execute({}, {}, {rows, cols, diag});

View File

@ -242,10 +242,10 @@ TEST_F(DeclarableOpsTests5, Test_SetSeed_1) {
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterMul_test1) { TEST_F(DeclarableOpsTests5, scatterMul_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {10.f, 2.f, 3.f, 4.f});
nd4j::ops::scatter_mul op; nd4j::ops::scatter_mul op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@ -260,10 +260,10 @@ TEST_F(DeclarableOpsTests5, scatterMul_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterDiv_test1) { TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10, 2, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f});
nd4j::ops::scatter_div op; nd4j::ops::scatter_div op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@ -278,10 +278,10 @@ TEST_F(DeclarableOpsTests5, scatterDiv_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, scatterSub_test1) { TEST_F(DeclarableOpsTests5, scatterSub_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64);
auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10, 1}); auto updates = NDArrayFactory::create<float>('c', {1, 2}, {10.f, 1.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9, 1, 3, 4}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f});
nd4j::ops::scatter_sub op; nd4j::ops::scatter_sub op;
auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {});
@ -296,8 +296,8 @@ TEST_F(DeclarableOpsTests5, scatterSub_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.7, 0.9, 1, 1}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f});
nd4j::ops::hardsigmoid op; nd4j::ops::hardsigmoid op;
auto result = op.execute({&matrix}, {}, {}, {}); auto result = op.execute({&matrix}, {}, {}, {});
@ -311,9 +311,9 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
auto eps = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto eps = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.2, 0.4, 0, 0}); auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f});
nd4j::ops::hardsigmoid_bp op; nd4j::ops::hardsigmoid_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {}); auto result = op.execute({&matrix, &eps}, {}, {}, {});
@ -327,8 +327,8 @@ TEST_F(DeclarableOpsTests5, hardsigmoid_test2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardtanh_test1) { TEST_F(DeclarableOpsTests5, hardtanh_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1});
nd4j::ops::hardtanh op; nd4j::ops::hardtanh op;
auto result = op.execute({&matrix}, {}, {}, {}); auto result = op.execute({&matrix}, {}, {}, {});
@ -342,9 +342,9 @@ TEST_F(DeclarableOpsTests5, hardtanh_test1) {
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, hardtanh_test2) { TEST_F(DeclarableOpsTests5, hardtanh_test2) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto eps = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0});
nd4j::ops::hardtanh_bp op; nd4j::ops::hardtanh_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {}); auto result = op.execute({&matrix, &eps}, {}, {}, {});
@ -389,7 +389,7 @@ TEST_F(DeclarableOpsTests5, histogram_test2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test1) { TEST_F(DeclarableOpsTests5, Identity_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f});
// auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 3}, {3, 3, 3}); // auto exp = NDArrayFactory::create<Nd4jLong>('c', {3, 3}, {3, 3, 3});
nd4j::ops::identity op; nd4j::ops::identity op;
@ -404,8 +404,8 @@ TEST_F(DeclarableOpsTests5, Identity_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Identity_test2) { TEST_F(DeclarableOpsTests5, Identity_test2) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4});
auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); auto eps = NDArrayFactory::create<double>('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
// auto exp = NDArrayFactory::create<float>('c', {3,3}); // auto exp = NDArrayFactory::create<float>('c', {3,3});
nd4j::ops::identity_bp op; nd4j::ops::identity_bp op;
auto result = op.execute({&matrix, &eps}, {}, {}, {}); auto result = op.execute({&matrix, &eps}, {}, {}, {});
@ -418,8 +418,8 @@ TEST_F(DeclarableOpsTests5, Identity_test2) {
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, Log1p_test1) { TEST_F(DeclarableOpsTests5, Log1p_test1) {
auto matrix = NDArrayFactory::create<float>('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); auto matrix = NDArrayFactory::create<double>('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4});
auto y = NDArrayFactory::create<float>('c', {3,3}, {5,4,3,2,1,2,3,4,5}); auto y = NDArrayFactory::create<double>('c', {3,3}, {5,4,3,2,1,2,3,4,5});
// auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); // auto eps = NDArrayFactory::create<float>('c', {3, 3}, {1,2,3,4,5,6,7,8,9});
// auto exp = NDArrayFactory::create<float>('c', {3,3}); // auto exp = NDArrayFactory::create<float>('c', {3,3});
nd4j::ops::Log1p op; nd4j::ops::Log1p op;
@ -599,7 +599,7 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test1) { TEST_F(DeclarableOpsTests5, eye_test1) {
auto expected = NDArrayFactory::create<float>('c', {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1}); auto expected = NDArrayFactory::create<float>('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f});
nd4j::ops::eye op; nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3}); auto results = op.execute({}, {}, {-99, 3});
@ -616,7 +616,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test2) { TEST_F(DeclarableOpsTests5, eye_test2) {
auto expected = NDArrayFactory::create<float>('c', {3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); auto expected = NDArrayFactory::create<float>('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f});
nd4j::ops::eye op; nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4}); auto results = op.execute({}, {}, {-99, 3, 4});

View File

@ -348,8 +348,8 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
} }
TEST_F(DeclarableOpsTests6, cumSum_2) { TEST_F(DeclarableOpsTests6, cumSum_2) {
auto x= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4}); auto x= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1, 3, 6, 10, 1, 3, 6, 10}); auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}); auto result = op.execute({&x}, {}, {0, 0, 1});
@ -365,8 +365,8 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
} }
TEST_F(DeclarableOpsTests6, cumSum_3) { TEST_F(DeclarableOpsTests6, cumSum_3) {
auto x= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4}); auto x= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 2, 4, 6, 8}); auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 0}); auto result = op.execute({&x}, {}, {0, 0, 0});
@ -649,13 +649,13 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
NDArray exp0 = exp(0, {0}); NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0}); NDArray exp1 = exp(1, {0});
exp0.p<float>(0, 1.); exp0.p(0, 1.);
exp1.p<float>(0, 1.); exp1.p(0, 1.);
for (int i = 1; i < 1500; ++i) { for (int i = 1; i < 1500; ++i) {
const auto prev = exp0.e<float>(i-1); const auto prev = exp0.e<float>(i-1);
exp0.p<float>(i, prev + i + 1); exp0.p(i, prev + i + 1);
exp1.p<float>(i, prev + i + 1); exp1.p(i, prev + i + 1);
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
@ -682,13 +682,13 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
NDArray exp0 = exp(0, {0}); NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0}); NDArray exp1 = exp(1, {0});
exp0.p<float>(0, 0.); exp0.p(0, 0.);
exp1.p<float>(0, 0.); exp1.p(0, 0.);
for (int i = 1; i < 1500; ++i) { for (int i = 1; i < 1500; ++i) {
const auto prev = exp0.e<float>(i-1); const auto prev = exp0.e<float>(i-1);
exp0.p<float>(i, prev + i); exp0.p(i, prev + i);
exp1.p<float>(i, prev + i); exp1.p(i, prev + i);
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
@ -715,13 +715,13 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
NDArray exp0 = exp(0, {0}); NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0}); NDArray exp1 = exp(1, {0});
exp0.p<float>(1499, 1500.); exp0.p(1499, 1500.f);
exp1.p<float>(1499, 1500.); exp1.p(1499, 1500.f);
for (int i = 1498; i >= 0; --i) { for (int i = 1498; i >= 0; --i) {
const auto prev = exp0.e<float>(i + 1); const auto prev = exp0.e<float>(i + 1);
exp0.p<float>(i, prev + i + 1); exp0.p(i, prev + i + 1);
exp1.p<float>(i, prev + i + 1); exp1.p(i, prev + i + 1);
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
@ -749,13 +749,13 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
NDArray exp0 = exp(0, {0}); NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0}); NDArray exp1 = exp(1, {0});
exp0.p<float>(1499, 0.); exp0.p(1499, 0.);
exp1.p<float>(1499, 0.); exp1.p(1499, 0.);
for (int i = 1498; i >= 0; --i) { for (int i = 1498; i >= 0; --i) {
const auto prev = exp0.e<float>(i + 1); const auto prev = exp0.e<float>(i + 1);
exp0.p<float>(i, prev + i + 2); exp0.p(i, prev + i + 2);
exp1.p<float>(i, prev + i + 2); exp1.p(i, prev + i + 2);
} }
nd4j::ops::cumsum op; nd4j::ops::cumsum op;
@ -1576,7 +1576,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_1) { TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
auto x = NDArrayFactory::create<float>('c', {2, 5, 5}, { auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, {
2., 4., 60., 8., 10., 2., 4., 60., 8., 10.,
0., 1., 2., 3., 4., 0., 1., 2., 3., 4.,
0., 0., 2., 4., 6., 0., 0., 2., 4., 6.,
@ -1590,7 +1590,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
5., 4., 3., 2., 1., 5., 4., 3., 2., 1.,
}); });
auto exp = NDArrayFactory::create<float>('c', {2, 5, 5}, { auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, {
0.5, -2.0, -13.0, 54.0, -6.75, 0.5, -2.0, -13.0, 54.0, -6.75,
0.0, 1.0, -1.0, 1.0, 0.0, 0.0, 1.0, -1.0, 1.0, 0.0,
0, 0, 0.5, -2.0, 0.25, 0, 0, 0.5, -2.0, 0.25,
@ -1620,8 +1620,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_010) { TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {1., 0., 0., 0., 0.,2., 1., 0., 0., 0.,30., 2., 1., 0., 0.,4., 3., 2., 1., 0.,5., 4., 3., 2., 1.,}); auto x = NDArrayFactory::create<double>('c', {1, 5, 5}, {1., 0., 0., 0., 0.,2., 1., 0., 0., 0.,30., 2., 1., 0., 0.,4., 3., 2., 1., 0.,5., 4., 3., 2., 1.,});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0.,-2.0, 1.0, 0., 0., 0.,-26.0, -2.0, 1, 0, 0.,54.0, 1.0, -2.0, 1, 0.,-27.0, 0.0, 1.0, -2.0, 1.}); auto exp = NDArrayFactory::create<double>('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0.,-2.0, 1.0, 0., 0., 0.,-26.0, -2.0, 1, 0, 0.,54.0, 1.0, -2.0, 1, 0.,-27.0, 0.0, 1.0, -2.0, 1.});
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
@ -1639,9 +1639,9 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) { TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., 0., 0., 0., 1., 2., 0., 0., 0., 0., 4. }); auto x = NDArrayFactory::create<double>('c', {1, 5, 5}, {2., 4., 60., 8., 10., 0., 1., 2., 3., 4., 0., 0., 2., 4., 6., 0., 0., 0., 1., 2., 0., 0., 0., 0., 4. });
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, 0, 0, 0, 1.0, -0.5, 0, 0, 0, 0, 0.25 }); auto exp = NDArrayFactory::create<double>('c', {1, 5, 5}, {0.5, -2.0, -13.0, 54.0, -6.75, 0.0, 1.0, -1.0, 1.0, 0.0, 0, 0, 0.5, -2.0, 0.25, 0, 0, 0, 1.0, -0.5, 0, 0, 0, 0, 0.25 });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
@ -1658,8 +1658,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_02) { TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 30., 2., 1., 0., 0., 4., 3., 2., 1., 0., 5., 4., 3., 2., 1. }); auto x = NDArrayFactory::create<double>('c', {1, 5, 5}, {1., 0., 0., 0., 0., 2., 1., 0., 0., 0., 30., 2., 1., 0., 0., 4., 3., 2., 1., 0., 5., 4., 3., 2., 1. });
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0., -2.0, 1.0, 0., 0., 0., -26.0, -2.0, 1, 0, 0., 54.0, 1.0, -2.0, 1, 0., -27.0, 0.0, 1.0, -2.0, 1. }); auto exp = NDArrayFactory::create<double>('c', {1, 5, 5}, {1.0, 0.0, 0.0, 0.0, 0., -2.0, 1.0, 0., 0., 0., -26.0, -2.0, 1, 0, 0., 54.0, 1.0, -2.0, 1, 0., -27.0, 0.0, 1.0, -2.0, 1. });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32); auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
@ -1724,19 +1724,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
TEST_F(DeclarableOpsTests6, MatrixInverse_03) { TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0., 4.f, 0.f, 0.f, 0.f, 0.f,
4., 2., 0., 0., 0., 4.f, 2.f, 0.f, 0.f, 0.f,
30., 2., 1., 0., 0., 30.f, 2.f, 1.f, 0.f, 0.f,
8., 6., 4., 2., 0., 8.f, 6.f, 4.f, 2.f, 0.f,
15., 12., 9., 6., 3., 15.f, 12.f, 9.f, 6.f, 3.f,
}); });
auto exp = NDArrayFactory::create<float>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0, 0.25f, 0.0f, 0.0f, 0.0f, 0.0f,
-0.50, 0.5, 0.0, 0.0, 0.0, -0.50f, 0.5f, 0.0f, 0.0f, 0.0f,
-6.50, -1.0, 1.0, 0.0, 0.0, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f,
13.50, 0.5, -2.0, 0.5, 0.0, 13.50f, 0.5f, -2.0f, 0.5f, 0.0f,
-6.75, 0.0, 1.0, -1.0, 0.33333333 -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
@ -1758,19 +1758,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
TEST_F(DeclarableOpsTests6, MatrixInverse_3) { TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0., 4.f, 0.f, 0.f, 0.f, 0.f,
4., 2., 0., 0., 0., 4.f, 2.f, 0.f, 0.f, 0.f,
30., 2., 1., 0., 0., 30.f, 2.f, 1.f, 0.f, 0.f,
8., 6., 4., 2., 0., 8.f, 6.f, 4.f, 2.f, 0.f,
15., 12., 9., 6., 3., 15.f, 12.f, 9.f, 6.f, 3.f,
}); });
auto exp = NDArrayFactory::create<float>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0, 0.25f, 0.0f, 0.0f, 0.0f, 0.0f,
-0.50, 0.5, 0.0, 0.0, 0.0, -0.50f, 0.5f, 0.0f, 0.0f, 0.0f,
-6.50, -1.0, 1.0, 0.0, 0.0, -6.50f, -1.0f, 1.0f, 0.0f, 0.0f,
13.50, 0.5, -2.0, 0.5, 0.0, 13.50f, 0.5f, -2.0f, 0.5f, 0.0f,
-6.75, 0.0, 1.0, -1.0, 0.33333333 -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
@ -1792,19 +1792,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
TEST_F(DeclarableOpsTests6, MatrixInverse_4) { TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5., 1.f, 2.f, 30.f, 4.f, 5.f,
0., 1., 2., 3., 4., 0.f, 1.f, 2.f, 3.f, 4.f,
0., 0., 1., 2., 3., 0.f, 0.f, 1.f, 2.f, 3.f,
0., 0., 0., 1., 2., 0.f, 0.f, 0.f, 1.f, 2.f,
0., 0., 0., 0., 1. 0.f, 0.f, 0.f, 0.f, 1.f
}); });
auto exp = NDArrayFactory::create<float>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0, 1.0f, -2.0f, -26.0f, 54.0f, -27.0f,
0.0, 1.0, -2.0, 1.0, 0.0, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f,
0.0, 0.0, 1.0, -2.0, 1.0, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f,
0.0, 0.0, 0.0, 1.0, -2.0, 0.0f, 0.0f, 0.0f, 1.0f, -2.0f,
0.0, 0.0, 0.0, 0.0, 1.0 0.0f, 0.0f, 0.0f, 0.0f, 1.0f
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;
@ -1826,19 +1826,19 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
TEST_F(DeclarableOpsTests6, MatrixInverse_04) { TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, { auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5., 1.f, 2.f, 30.f, 4.f, 5.f,
0., 1., 2., 3., 4., 0.f, 1.f, 2.f, 3.f, 4.f,
0., 0., 1., 2., 3., 0.f, 0.f, 1.f, 2.f, 3.f,
0., 0., 0., 1., 2., 0.f, 0.f, 0.f, 1.f, 2.f,
0., 0., 0., 0., 1. 0.f, 0.f, 0.f, 0.f, 1.f
}); });
auto exp = NDArrayFactory::create<float>('c', {5, 5}, { auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0, 1.0f, -2.0f, -26.0f, 54.0f, -27.0f,
0.0, 1.0, -2.0, 1.0, 0.0, 0.0f, 1.0f, -2.0f, 1.0f, 0.0f,
0.0, 0.0, 1.0, -2.0, 1.0, 0.0f, 0.0f, 1.0f, -2.0f, 1.0f,
0.0, 0.0, 0.0, 1.0, -2.0, 0.0f, 0.0f, 0.0f, 1.0f, -2.0f,
0.0, 0.0, 0.0, 0.0, 1.0 0.0f, 0.0f, 0.0f, 0.0f, 1.0f
}); });
nd4j::ops::matrix_inverse op; nd4j::ops::matrix_inverse op;

View File

@ -1097,9 +1097,9 @@ TEST_F(DeclarableOpsTests7, TestSegmentMin_01) {
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { TEST_F(DeclarableOpsTests7, TestSegmentMin_02) {
auto x = NDArrayFactory::create<float>({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); auto x = NDArrayFactory::create<float>({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, -3.f, -9.f, 2.1f, 2.1f,0.7f, 0.1f, 3.f, -4.2f, 2.2f, 1.f});
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4});
auto exp = NDArrayFactory::create<float>({-2.5, -9, -3., -9, -4.2}); auto exp = NDArrayFactory::create<float>({-2.5f, -9.f, -3.f, -9.f, -4.2f});
nd4j::ops::segment_min op; nd4j::ops::segment_min op;
@ -1432,7 +1432,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_02) {
TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { TEST_F(DeclarableOpsTests7, TestSegmentMean_021) {
auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2}); auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f});
nd4j::ops::segment_mean op; nd4j::ops::segment_mean op;
x.linspace(1.); x.linspace(1.);
@ -1448,7 +1448,7 @@ TEST_F(DeclarableOpsTests7, TestSegmentMean_022) {
auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); auto x = NDArrayFactory::create<float>('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.});
auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2}); auto idx = NDArrayFactory::create<int>({0, 0, 1, 1, 2,2});
auto z = NDArrayFactory::create<float>('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); auto z = NDArrayFactory::create<float>('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); auto exp = NDArrayFactory::create<float>('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f});
nd4j::ops::segment_mean op; nd4j::ops::segment_mean op;
x.linspace(1.); x.linspace(1.);
@ -3897,9 +3897,9 @@ TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) {
TEST_F(DeclarableOpsTests7, RealDiv_1) { TEST_F(DeclarableOpsTests7, RealDiv_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 4.f});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2}); NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1.f,2.f});
NDArray e = NDArrayFactory::create<float>('c', {1, 2, 2}, {2, 1, 4, 2}); NDArray e = NDArrayFactory::create<float>('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f});
nd4j::ops::realdiv op; nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {}); auto result = op.execute({&x, &y}, {}, {});
@ -3917,11 +3917,11 @@ TEST_F(DeclarableOpsTests7, RealDiv_1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 4.f});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2}); NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1.f, 2.f});
NDArray e0 = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 5}); NDArray e0 = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 5.f});
NDArray e1 = NDArrayFactory::create<float>('c', {1, 2}, {-14, -5}); NDArray e1 = NDArrayFactory::create<float>('c', {1, 2}, {-14.f, -5.f});
NDArray eps = NDArrayFactory::create<float>('c', {1, 2, 2}, {1, 2, 3, 4}); NDArray eps = NDArrayFactory::create<float>('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f});
nd4j::ops::realdiv_bp op; nd4j::ops::realdiv_bp op;
auto result = op.execute({&x, &y, &eps}, {}, {}); auto result = op.execute({&x, &y, &eps}, {}, {});
@ -3944,7 +3944,7 @@ TEST_F(DeclarableOpsTests7, RealDiv_BP_1) {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_1) { TEST_F(DeclarableOpsTests7, ShapesOf_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 4.f});
// NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2}); // NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2});
NDArray e = NDArrayFactory::create<Nd4jLong>({1, 2, 1}); NDArray e = NDArrayFactory::create<Nd4jLong>({1, 2, 1});
@ -3964,8 +3964,8 @@ TEST_F(DeclarableOpsTests7, ShapesOf_1) {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, ShapesOf_2) { TEST_F(DeclarableOpsTests7, ShapesOf_2) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 4.f});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1,2}); NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {1.f, 2.f});
NDArray e0 = NDArrayFactory::create<Nd4jLong>({1, 2, 1}); NDArray e0 = NDArrayFactory::create<Nd4jLong>({1, 2, 1});
NDArray e1 = NDArrayFactory::create<Nd4jLong>({1, 2}); NDArray e1 = NDArrayFactory::create<Nd4jLong>({1, 2});
@ -3987,8 +3987,8 @@ TEST_F(DeclarableOpsTests7, ShapesOf_2) {
TEST_F(DeclarableOpsTests7, Size_1) { TEST_F(DeclarableOpsTests7, Size_1) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2.f, 4.f});
NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f});
NDArray e = NDArrayFactory::create<Nd4jLong>(2); NDArray e = NDArrayFactory::create<Nd4jLong>(2);
nd4j::ops::size op; nd4j::ops::size op;
@ -4006,8 +4006,8 @@ TEST_F(DeclarableOpsTests7, Size_1) {
TEST_F(DeclarableOpsTests7, Size_2) { TEST_F(DeclarableOpsTests7, Size_2) {
NDArray x = NDArrayFactory::create<float>('c', {1, 2, 1}, {2, 4}); NDArray x = NDArrayFactory::create<double>('c', {1, 2, 1}, {2, 4});
NDArray y = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray y = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<Nd4jLong>(10); NDArray e = NDArrayFactory::create<Nd4jLong>(10);
nd4j::ops::size op; nd4j::ops::size op;
@ -4025,8 +4025,8 @@ TEST_F(DeclarableOpsTests7, Size_2) {
TEST_F(DeclarableOpsTests7, Softplus_1) { TEST_F(DeclarableOpsTests7, Softplus_1) {
NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); NDArray e = NDArrayFactory::create<double>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
nd4j::ops::softplus op; nd4j::ops::softplus op;
auto result = op.execute({&x}, {}, {}); auto result = op.execute({&x}, {}, {});
@ -4065,8 +4065,8 @@ TEST_F(DeclarableOpsTests7, Softplus_BP_1) {
TEST_F(DeclarableOpsTests7, Softsign_1) { TEST_F(DeclarableOpsTests7, Softsign_1) {
NDArray x = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray x = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); NDArray e = NDArrayFactory::create<double>('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667});
nd4j::ops::softsign op; nd4j::ops::softsign op;
auto result = op.execute({&x}, {}, {}); auto result = op.execute({&x}, {}, {});
@ -4213,7 +4213,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expI = NDArrayFactory::create<int>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expI = NDArrayFactory::create<int>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expL = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expL = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
NDArray expF16 = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); NDArray expF16 = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
nd4j::ops::to_int32 op32; nd4j::ops::to_int32 op32;
@ -4239,7 +4239,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test1) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test2) { TEST_F(DeclarableOpsTests7, TypesConversion_test2) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expF = NDArrayFactory::create<float>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
NDArray expH = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); NDArray expH = NDArrayFactory::create<float16>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
nd4j::ops::to_float32 op32; nd4j::ops::to_float32 op32;
@ -4291,7 +4291,7 @@ TEST_F(DeclarableOpsTests7, TypesConversion_test3) {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests7, TypesConversion_test4) { TEST_F(DeclarableOpsTests7, TypesConversion_test4) {
NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray x = NDArrayFactory::create<Nd4jLong>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
NDArray exp32 = NDArrayFactory::create<float>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp32 = NDArrayFactory::create<float>('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f});
NDArray exp64 = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp64 = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
nd4j::ops::to_float32 op32; nd4j::ops::to_float32 op32;
@ -4968,8 +4968,8 @@ TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) {
} }
TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) {
auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0}); auto input = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0}); auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f});
nd4j::ops::pnormpool2d op; nd4j::ops::pnormpool2d op;
auto result = op.execute({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); auto result = op.execute({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0});

View File

@ -3614,7 +3614,7 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_1) {
eps.linspace(1); eps.linspace(1);
// //
auto exp = NDArrayFactory::create<TypeParam>('c', {3,3,5,5}, { auto exp = NDArrayFactory::create<TypeParam>('c', {3,3,5,5}, {
0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} 0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, 0.357762f, 0.375373f, 0.384874f, 0.384875f, 0.374618f, 0.357763f, 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, 0.357766f, 0.375275f, 0.384886f}
); );
/// ///
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;
@ -3636,65 +3636,65 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) {
auto x = NDArrayFactory::create<TypeParam>( 'c', {3, 3, 5, 5}); auto x = NDArrayFactory::create<TypeParam>( 'c', {3, 3, 5, 5});
x.linspace(1); x.linspace(1);
auto eps = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, { 0.2581989 ,0.3592106 , 0.40089184, 0.53935987, 0.70014, auto eps = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f,
0.4898979 ,0.46056613, 0.43971977, 0.5240002 , 0.6375767, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f,
0.5274096 ,0.47771242, 0.4443308 , 0.5163977 , 0.61701745, 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f,
0.5424508 ,0.48452914, 0.44570294, 0.5123918 , 0.6068971, 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f,
0.5505386 ,0.4881662 , 0.4462865 , 0.5099462 , 0.60088515, 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f,
0.5555859 , 0.49042296, 0.44658744, 0.5083028 , 0.59690416, 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f,
0.55903524, 0.4919585 , 0.44676256, 0.5071239 , 0.59407425, 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f,
0.5615412 , 0.49307042, 0.44687328, 0.50623745, 0.5919596 , 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f,
0.56344414, 0.49391258, 0.4469477 , 0.5055468 , 0.59031945, 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f,
0.56493837, 0.49457246, 0.4470002 , 0.5049936 , 0.5890103 , 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f,
0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411 , 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f,
0.567134 , 0.49553978, 0.4470674 , 0.504163 , 0.5870515 , 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f,
0.5679643 , 0.4959048 , 0.44708967, 0.5038433 , 0.5862998 , 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f,
0.56866974, 0.4962146 , 0.44710726, 0.5035692 , 0.58565617, 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f,
0.56927663, 0.49648085, 0.4471213 , 0.5033315 , 0.5850988 , 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f,
0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165, 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f,
0.57026696, 0.49691492, 0.4471422 , 0.50293994, 0.58418214, 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f,
0.5706764 , 0.49709415, 0.44715008, 0.5027767 , 0.5838005 , 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f,
0.571041 , 0.4972537 , 0.44715673, 0.50263065, 0.58345926, 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f,
0.57136786, 0.49739665, 0.44716236, 0.5024992 , 0.58315235, 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f,
0.5716625 , 0.49752548, 0.4471672 , 0.5023803, 0.5828747 , 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f,
0.5719295 , 0.49764213, 0.44717142, 0.5022721, 0.5826225 , 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f,
0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233, 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f,
0.5723947 , 0.4978453 , 0.44717824, 0.5020829, 0.58218133, 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f,
0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874 , 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f,
0.5727864 , 0.49801624, 0.44718358, 0.5019227, 0.5818083 , 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f,
0.57296 , 0.49809194, 0.44718578, 0.5018515, 0.5816426 , 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f,
0.5731208 , 0.49816203, 0.44718775, 0.5017854, 0.58148885, 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f,
0.57327026, 0.49822718, 0.4471895 , 0.5017239, 0.5813457 , 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f,
0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212 , 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f,
0.57353944, 0.4983446 , 0.44719255, 0.50161266, 0.58108705, 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f,
0.5736612 , 0.49839762, 0.4471939 , 0.50156236, 0.5809699 , 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f,
0.5737754 , 0.4984474 , 0.44719502, 0.501515 , 0.58085984, 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f,
0.5738828 , 0.49849418, 0.4471962 , 0.50147045, 0.5807564 , 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f,
0.5739839 , 0.49853817, 0.44719717, 0.5014284 , 0.5806588 , 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f,
0.5740793 , 0.49857965, 0.4471981 , 0.5013887 , 0.5805666 , 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f,
0.5741694 , 0.49861887, 0.44719887, 0.50135124, 0.58047944, 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f,
0.57425463, 0.49865603, 0.44719967, 0.5013157 , 0.5803969 , 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f,
0.5743354 , 0.4986912 , 0.44720036, 0.5012819 , 0.5803186 , 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f,
0.57441217, 0.49872455, 0.44720104, 0.5012499 , 0.58024424, 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f,
0.57448506, 0.4987563 , 0.44720164, 0.5012194 , 0.58017343, 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f,
0.57455444, 0.4987865 , 0.4472022 , 0.5011904 , 0.5801061, 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f,
0.57462054, 0.49881527, 0.44720277, 0.5011627 , 0.5800419, 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f,
0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805, 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f,
0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }); 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f });
// //
auto exp = NDArrayFactory::create<TypeParam>('c', {3,3,5,5}, { auto exp = NDArrayFactory::create<TypeParam>('c', {3,3,5,5}, {
0.061538, 0.055617, 0.044643, 0.050772, 0.048019, 0.030270, 0.023819, 0.019468, 0.022074, 0.023990, 0.018221, 0.014664, 0.012182, 0.013954, 0.015685, 0.012967, 0.010563, 0.008841, 0.010185, 0.011621, 0.010052, 0.008248, 0.006934, 0.008015, 0.009222, 0.008204, 0.006764, 0.005702, 0.006606, 0.007642, 0.006929, 0.005732, 0.004841, 0.005618, 0.006523, 0.005996, 0.004973, 0.004205, 0.004887, 0.005689, 0.005284, 0.004391, 0.003717, 0.004324, 0.005044, 0.004723, 0.003931, 0.003331, 0.003877, 0.004531, 0.004270, 0.003558, 0.003017, 0.003514, 0.004112, 0.003896, 0.003250, 0.002757, 0.003213, 0.003764, 0.003582, 0.002991, 0.002539, 0.002959, 0.003470, 0.003315, 0.002770, 0.002352, 0.002743, 0.003219, 0.003085, 0.002580, 0.002191, 0.002556, 0.003002, 0.002885, 0.002414, 0.002051, 0.002393, 0.002812, 0.002709, 0.002268, 0.001927, 0.002250, 0.002645, 0.002553, 0.002138, 0.001818, 0.002122, 0.002496, 0.002415, 0.002023, 0.001720, 0.002009, 0.002363, 0.002290, 0.001920, 0.001632, 0.001906, 0.002244, 0.002178, 0.001826, 0.001553, 0.001814, 0.002136, 0.002076, 0.001741, 0.001481, 0.001731, 0.002038, 0.001984, 0.001664, 0.001416, 0.001654, 0.001949, 0.001899, 0.001593, 0.001356, 0.001584, 0.001867, 0.001821, 0.001528, 0.001301, 0.001520, 0.001792, 0.001750, 0.001469, 0.001250, 0.001461, 0.001722, 0.001683, 0.001413, 0.001203, 0.001406, 0.001658, 0.001622, 0.001362, 0.001159, 0.001355, 0.001599, 0.001565, 0.001314, 0.001119, 0.001308, 0.001543, 0.001512, 0.001270, 0.001081, 0.001264, 0.001491, 0.001462, 0.001228, 0.001046, 0.001223, 0.001443, 0.001415, 0.001189, 0.001013, 0.001184, 0.001397, 0.001372, 0.001153, 0.000982, 0.001148, 0.001355, 0.001331, 0.001118, 0.000952, 0.001114, 0.001315, 0.001292, 0.001086, 0.000925, 0.001082, 0.001277, 0.001255, 0.001055, 0.000899, 0.001051, 0.001241, 0.001221, 0.001026, 0.000874, 0.001023, 0.001208, 0.001188, 0.000999, 0.000851, 0.000996, 0.001176, 0.001157, 0.000973, 0.000829, 0.000970, 0.001145, 0.001128, 0.000949, 0.000808, 0.000945, 0.001117, 0.001100, 0.000925, 0.000788, 0.000922, 0.001089, 0.001073, 0.000903, 0.000769, 0.000900, 0.001063, 0.001048, 0.000882, 0.000751, 0.000879, 0.001038, 0.001024, 0.000861, 0.000734, 0.000859, 0.001015, 0.001001, 0.000842, 0.000717, 0.000840, 0.000992} 0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, 0.001189f, 0.001013f, 0.001184f, 0.001397f, 0.001372f, 0.001153f, 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, 0.000717f, 0.000840f, 0.000992f}
// 0.009859, 0.013075, 0.013874, 0.017893, 0.022344, 0.014551, 0.012859, 0.011511, 0.013311, 0.015834, 0.012025, 0.010047, 0.008601, 0.009920, 0.011885, 0.009505, 0.007636, 0.006299, 0.007413, 0.009095, 0.007446, 0.005743, 0.004540, 0.005533, 0.007033, 0.005821, 0.004282, 0.003209, 0.004123, 0.005491, 0.004577, 0.003198, 0.002247, 0.003097, 0.004355, 0.003652, 0.002412, 0.001565, 0.002357, 0.003517, 0.002965, 0.001844, 0.001084, 0.001821, 0.002893, 0.002451, 0.001430, 0.000741, 0.001428, 0.002422, -0.111434, -0.105946, -0.100351, -0.091868, -0.083323, -0.078775, -0.076222, -0.073291, -0.067635, -0.061692, -0.058943, -0.057832, -0.056263, -0.052198, -0.047768, -0.046002, -0.045655, -0.044839, -0.041748, -0.038271, -0.037084, -0.037161, -0.036786, -0.034331, -0.031495, 0.000077, -0.000673, -0.001181, -0.000667, 0.000079, -0.000089, -0.000802, -0.001285, -0.000793, -0.000079, -0.000228, -0.000908, -0.001368, -0.000896, -0.000212, -0.000345, -0.000996, -0.001434, -0.000981, -0.000325, -0.000444, -0.001067, -0.001487, -0.001051, -0.000421, 0.000697, 0.000188, -0.000152, 0.000210, 0.000731, 0.000650, 0.000165, -0.000161, 0.000185, 0.000683, 0.000610, 0.000145, -0.000168, 0.000164, 0.000641, 0.000574, 0.000128, -0.000172, 0.000146, 0.000604, 0.000542, 0.000113, -0.000175, 0.000131, 0.000571, -0.009490, -0.010070, -0.010409, -0.009734, -0.008834, -0.008785, -0.009351, -0.009687, -0.009054, -0.008207, -0.008167, -0.008718, -0.009050, -0.008455, -0.007654, -0.007622, -0.008159, -0.008485, -0.007924, -0.007164, -0.007138, -0.007661, -0.007981, -0.007450, -0.006728, -0.000901, -0.001327, -0.001614, -0.001310, -0.000869, -0.000913, -0.001328, -0.001607, -0.001310, -0.000882, -0.000922, -0.001326, -0.001598, -0.001309, -0.000892, -0.000930, -0.001323, -0.001588, -0.001306, -0.000900, -0.000936, -0.001319, -0.001577, -0.001302, -0.000906, 0.000339, 0.000038, -0.000164, 0.000048, 0.000355, 0.000328, 0.000035, -0.000162, 0.000045, 0.000343, 0.000318, 0.000033, -0.000159, 0.000041, 0.000332, 0.000308, 0.000030, -0.000157, 0.000039, 0.000322, 0.000299, 0.000028, -0.000155, 0.000036, 0.000312, -0.004085, -0.004479, -0.004733, -0.004396, -0.003925, -0.003925, -0.004309, -0.004558, -0.004232, -0.003775, -0.003776, -0.004151, -0.004395, -0.004079, -0.003636, -0.003637, -0.004004, -0.004242, -0.003936, -0.003505, -0.003507, -0.003866, -0.004100, -0.003802, -0.003383} // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, 0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, -0.003866f, -0.004100f, -0.003802f, -0.003383f}
); );
nd4j::ops::lrn_bp op; nd4j::ops::lrn_bp op;

View File

@ -1903,13 +1903,13 @@ TEST_F(DeclarableOpsTests9, cumprod_2) {
NDArray exp0 = exp(0, {0}); NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0}); NDArray exp1 = exp(1, {0});
exp0.p<float>(0, 1.); exp0.p(0, 1.f);
exp1.p<float>(0, 1.); exp1.p(0, 1.f);
for (int i = 1; i < 1500; ++i) { for (int i = 1; i < 1500; ++i) {
const auto prev = exp0.e<float>(i-1); const auto prev = exp0.e<float>(i-1);
exp0.p<float>(i, prev * x0.e<float>(i)); exp0.p(i, prev * x0.e<float>(i));
exp1.p<float>(i, prev * x1.e<float>(i)); exp1.p(i, prev * x1.e<float>(i));
} }
nd4j::ops::cumprod op; nd4j::ops::cumprod op;
@ -3331,8 +3331,8 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); NDArray x = NDArrayFactory::create<float>('c', {2, 3, 3}, {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 1.f, 2.f, 6.f});
NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.}); NDArray exp = NDArrayFactory::create<float>('c', {2, 3, 3}, {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, -8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 1.f, 1.f, 2.f});
nd4j::ops::cholesky op; nd4j::ops::cholesky op;

File diff suppressed because one or more lines are too long

View File

@ -94,7 +94,7 @@ TEST_F(NDArrayTest2, Test_Reshape_Scalar_2) {
} }
TEST_F(NDArrayTest2, Test_IndexReduce_1) { TEST_F(NDArrayTest2, Test_IndexReduce_1) {
auto x = NDArrayFactory::create<float>('c', {1, 5}, {1, 2, 3, 4, 5}); auto x = NDArrayFactory::create<double>('c', {1, 5}, {1, 2, 3, 4, 5});
ExtraArguments extras({3.0, 0.0, 10.0}); ExtraArguments extras({3.0, 0.0, 10.0});
int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e<int>(0); int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e<int>(0);
@ -160,7 +160,7 @@ TEST_F(NDArrayTest2, SetIdentity_test_5) {
TEST_F(NDArrayTest2, SetIdentity_test_6) { TEST_F(NDArrayTest2, SetIdentity_test_6) {
auto x = NDArrayFactory::create<float>('c', {3, 2}); auto x = NDArrayFactory::create<float>('c', {3, 2});
auto xExp = NDArrayFactory::create<float>('c', {3, 2}, {1,0,0,1,0,0}); auto xExp = NDArrayFactory::create<float>('c', {3, 2}, {1.f, 0.f, 0.f, 1.f, 0.f, 0.f});
x.setIdentity(); x.setIdentity();
@ -171,7 +171,7 @@ TEST_F(NDArrayTest2, SetIdentity_test_6) {
TEST_F(NDArrayTest2, SetIdentity_test_7) { TEST_F(NDArrayTest2, SetIdentity_test_7) {
auto x = NDArrayFactory::create<float>('c', {3, 4}); auto x = NDArrayFactory::create<float>('c', {3, 4});
auto xExp = NDArrayFactory::create<float>('c', {3, 4}, {1.,0.,0.,0.,0.,1.,0.,0.,0.,0.,1.,0.}); auto xExp = NDArrayFactory::create<float>('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f});
x.setIdentity(); x.setIdentity();
@ -192,9 +192,9 @@ TEST_F(NDArrayTest2, SetIdentity_test_8) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, Test_AllReduce3_1) { TEST_F(NDArrayTest2, Test_AllReduce3_1) {
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 1, 2, 3}); auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
auto y = NDArrayFactory::create<float>('c', {2, 3}, {2, 3, 4, 2, 3, 4}); auto y = NDArrayFactory::create<double>('c', {2, 3}, {2, 3, 4, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); auto exp = NDArrayFactory::create<double>('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205});
auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr);
@ -206,9 +206,9 @@ TEST_F(NDArrayTest2, Test_AllReduce3_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, Test_AllReduce3_2) { TEST_F(NDArrayTest2, Test_AllReduce3_2) {
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 2, 3, 4 });
auto y = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 2, 3, 4}); auto y = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); auto exp = NDArrayFactory::create<double>('c', {2, 2}, {0., 1.73205, 1.73205, 0.});
auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr); auto z = x.applyAllReduce3(reduce3::EuclideanDistance, &y, {1}, nullptr);
@ -221,9 +221,9 @@ TEST_F(NDArrayTest2, Test_AllReduce3_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, mmul_test1) { TEST_F(NDArrayTest2, mmul_test1) {
auto x = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4}); auto x = NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto y = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4}); auto y = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});
auto result = mmul(x, y); auto result = mmul(x, y);
ASSERT_TRUE(exp.isSameShape(&result)); ASSERT_TRUE(exp.isSameShape(&result));
@ -234,9 +234,9 @@ TEST_F(NDArrayTest2, mmul_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, mmul_test2) { TEST_F(NDArrayTest2, mmul_test2) {
auto x = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4}); auto x = NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto y = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4}); auto y = NDArrayFactory::create<double>('c', {1, 4}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {1, 1}, {30}); auto exp = NDArrayFactory::create<double>('c', {1, 1}, {30});
auto result = mmul(y ,x); auto result = mmul(y ,x);
@ -248,10 +248,10 @@ TEST_F(NDArrayTest2, mmul_test2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, mmul_test3) { TEST_F(NDArrayTest2, mmul_test3) {
auto x = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4}); auto x = NDArrayFactory::create<double>('c', {4, 1}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16});
auto w = NDArrayFactory::create<float>( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector auto w = NDArrayFactory::create<double>( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector
auto wT = NDArrayFactory::create<float>(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) auto wT = NDArrayFactory::create<double>(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w)
w = x / (float)10.; w = x / (float)10.;
w.p(0, 1.); w.p(0, 1.);
@ -311,9 +311,9 @@ TEST_F(NDArrayTest2, Test_Enforce_1) {
} }
TEST_F(NDArrayTest2, TestVector_1) { TEST_F(NDArrayTest2, TestVector_1) {
auto x = NDArrayFactory::create<float>('c', {2, 3}); auto x = NDArrayFactory::create<double>('c', {2, 3});
auto row = NDArrayFactory::create<float>('c', {3}, {1, 2, 3}); auto row = NDArrayFactory::create<double>('c', {3}, {1, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 1, 2, 3}); auto exp = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 1, 2, 3});
x.addiRowVector(&row); x.addiRowVector(&row);
@ -341,9 +341,9 @@ TEST_F(NDArrayTest2, Operator_Plus_Test_5)
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, Operator_Plus_Test_6) { TEST_F(NDArrayTest2, Operator_Plus_Test_6) {
auto x = NDArrayFactory::create<float>('c', {3, 3, 3}); auto x = NDArrayFactory::create<double>('c', {3, 3, 3});
auto y = NDArrayFactory::create<float>('c', {3, 1, 3}); auto y = NDArrayFactory::create<double>('c', {3, 1, 3});
auto expected = NDArrayFactory::create<float>('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); auto expected = NDArrayFactory::create<double>('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.});
x.linspace(1); x.linspace(1);
y.linspace(1); y.linspace(1);
@ -356,8 +356,8 @@ TEST_F(NDArrayTest2, Operator_Plus_Test_6) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, tileToShape_test1) { TEST_F(NDArrayTest2, tileToShape_test1) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1,2,3,4}); auto x = NDArrayFactory::create<double>('c', {2, 2}, {1,2,3,4});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1,2,3,4,1,2,3,4});
x.tileToShape({2,2,2}); x.tileToShape({2,2,2});
@ -368,8 +368,8 @@ TEST_F(NDArrayTest2, tileToShape_test1) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, tileToShape_test2) { TEST_F(NDArrayTest2, tileToShape_test2) {
auto x = NDArrayFactory::create<float>('c', {2, 1, 2}, {1,2,3,4}); auto x = NDArrayFactory::create<double>('c', {2, 1, 2}, {1,2,3,4});
auto exp = NDArrayFactory::create<float>('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4});
x.tileToShape({2,3,2}); x.tileToShape({2,3,2});
@ -380,9 +380,9 @@ TEST_F(NDArrayTest2, tileToShape_test2) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, tileToShape_test3) { TEST_F(NDArrayTest2, tileToShape_test3) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1,2,3,4}); auto x = NDArrayFactory::create<double>('c', {2, 2}, {1,2,3,4});
auto result = NDArrayFactory::create<float>('c', {2, 2, 2}); auto result = NDArrayFactory::create<double>('c', {2, 2, 2});
auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1,2,3,4,1,2,3,4});
x.tileToShape({2,2,2}, &result); x.tileToShape({2,2,2}, &result);
// result.printIndexedBuffer(); // result.printIndexedBuffer();
@ -394,9 +394,9 @@ TEST_F(NDArrayTest2, tileToShape_test3) {
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, tileToShape_test4) { TEST_F(NDArrayTest2, tileToShape_test4) {
auto x = NDArrayFactory::create<float>('c', {2, 1, 2}, {1,2,3,4}); auto x = NDArrayFactory::create<double>('c', {2, 1, 2}, {1,2,3,4});
auto result = NDArrayFactory::create<float>('c', {2, 3, 2}); auto result = NDArrayFactory::create<double>('c', {2, 3, 2});
auto exp = NDArrayFactory::create<float>('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); auto exp = NDArrayFactory::create<double>('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4});
x.tileToShape({2,3,2}, &result); x.tileToShape({2,3,2}, &result);
@ -407,50 +407,50 @@ TEST_F(NDArrayTest2, tileToShape_test4) {
#ifndef __CUDABLAS__ #ifndef __CUDABLAS__
TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) {
auto t = NDArrayFactory::create<float>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); auto t = NDArrayFactory::create<double>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
auto u = NDArrayFactory::create<float>('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); auto u = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
auto v = NDArrayFactory::create<float>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); auto v = NDArrayFactory::create<double>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});
float extra = 1.0f; float extra = 1.0f;
auto la = LAMBDA_FFF(_t, _u, _v, extra) { auto la = LAMBDA_DDD(_t, _u, _v, extra) {
return _t + _u + _v + extra; return _t + _u + _v + extra;
}; };
t.applyTriplewiseLambda<float>(&u, &v, la); t.applyTriplewiseLambda<double>(&u, &v, la);
ASSERT_TRUE(t.equalsTo(&exp)); ASSERT_TRUE(t.equalsTo(&exp));
} }
TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) {
auto t = NDArrayFactory::create<float>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); auto t = NDArrayFactory::create<double>('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1});
auto u = NDArrayFactory::create<float>('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); auto u = NDArrayFactory::create<double>('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2});
auto v = NDArrayFactory::create<float>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); auto v = NDArrayFactory::create<double>('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3});
auto exp = NDArrayFactory::create<float>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7});
float extra = 1.0f; float extra = 1.0f;
auto la = LAMBDA_FFF(_t, _u, _v, extra) { auto la = LAMBDA_DDD(_t, _u, _v, extra) {
return _t + _u + _v + extra; return _t + _u + _v + extra;
}; };
t.applyTriplewiseLambda<float>(&u, &v, la); t.applyTriplewiseLambda<double>(&u, &v, la);
ASSERT_TRUE(t.equalsTo(&exp)); ASSERT_TRUE(t.equalsTo(&exp));
} }
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, Test_Indexed_Lambda) { TEST_F(NDArrayTest2, Test_Indexed_Lambda) {
auto x = NDArrayFactory::create<float>('c', {2, 2}); auto x = NDArrayFactory::create<double>('c', {2, 2});
auto exp = NDArrayFactory::create<float>('c', {2, 2}, {0, 1, 2, 3}); auto exp = NDArrayFactory::create<double>('c', {2, 2}, {0, 1, 2, 3});
auto lambda = ILAMBDA_F(_x) { auto lambda = ILAMBDA_D(_x) {
return (float) _idx; return (float) _idx;
}; };
x.applyIndexedLambda<float>(lambda); x.applyIndexedLambda<double>(lambda);
ASSERT_TRUE(exp.equalsTo(&x)); ASSERT_TRUE(exp.equalsTo(&x));
} }
@ -458,8 +458,8 @@ TEST_F(NDArrayTest2, Test_Indexed_Lambda) {
#endif #endif
TEST_F(NDArrayTest2, Test_PermuteEquality_1) { TEST_F(NDArrayTest2, Test_PermuteEquality_1) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
auto exp = NDArrayFactory::create<float>('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); auto exp = NDArrayFactory::create<double>('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0});
x.linspace(1); x.linspace(1);
x.reshapei('c', {3, 4, 5}); x.reshapei('c', {3, 4, 5});
@ -474,9 +474,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_1) {
} }
TEST_F(NDArrayTest2, Test_PermuteEquality_0) { TEST_F(NDArrayTest2, Test_PermuteEquality_0) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
x.linspace(1); x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); auto exp = NDArrayFactory::create<double>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
x.reshapei('c', {3, 4, 5}); x.reshapei('c', {3, 4, 5});
x.permutei({0, 1, 2}); x.permutei({0, 1, 2});
@ -491,9 +491,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_0) {
TEST_F(NDArrayTest2, Test_PermuteEquality_2) { TEST_F(NDArrayTest2, Test_PermuteEquality_2) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
x.linspace(1); x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); auto exp = NDArrayFactory::create<double>('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0});
x.reshapei('c', {3, 4, 5}); x.reshapei('c', {3, 4, 5});
x.permutei({1, 0, 2}); x.permutei({1, 0, 2});
@ -507,9 +507,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_2) {
} }
TEST_F(NDArrayTest2, Test_PermuteEquality_3) { TEST_F(NDArrayTest2, Test_PermuteEquality_3) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
x.linspace(1); x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); auto exp = NDArrayFactory::create<double>('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0});
x.reshapei('c', {3, 4, 5}); x.reshapei('c', {3, 4, 5});
x.permutei({1, 2, 0}); x.permutei({1, 2, 0});
@ -523,9 +523,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_3) {
} }
TEST_F(NDArrayTest2, Test_PermuteEquality_4) { TEST_F(NDArrayTest2, Test_PermuteEquality_4) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
x.linspace(1); x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); auto exp = NDArrayFactory::create<double>('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0});
x.reshapei('c', {3, 4, 5}); x.reshapei('c', {3, 4, 5});
x.permutei({2, 0, 1}); x.permutei({2, 0, 1});
@ -539,9 +539,9 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_4) {
} }
TEST_F(NDArrayTest2, Test_PermuteEquality_5) { TEST_F(NDArrayTest2, Test_PermuteEquality_5) {
auto x = NDArrayFactory::create<float>('c', {1, 60}); auto x = NDArrayFactory::create<double>('c', {1, 60});
x.linspace(1); x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {5, 4, 3}, auto exp = NDArrayFactory::create<double>('c', {5, 4, 3},
{1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0,
27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0,
53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0,
@ -562,10 +562,10 @@ TEST_F(NDArrayTest2, Test_PermuteEquality_5) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, fillAsTriangular_test1) { TEST_F(NDArrayTest2, fillAsTriangular_test1) {
auto x = NDArrayFactory::create<float>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto x = NDArrayFactory::create<double>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16});
x.fillAsTriangular<float>(0., 0, 0, 'u'); x.fillAsTriangular<double>(0., 0, 0, 'u');
ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.isSameShape(&x));
ASSERT_TRUE(exp.equalsTo(&x)); ASSERT_TRUE(exp.equalsTo(&x));
@ -575,10 +575,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, fillAsTriangular_test2) { TEST_F(NDArrayTest2, fillAsTriangular_test2) {
auto x = NDArrayFactory::create<float>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto x = NDArrayFactory::create<double>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0});
x.fillAsTriangular<float>(0., 0, -1, 'u'); x.fillAsTriangular<double>(0., 0, -1, 'u');
ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.isSameShape(&x));
ASSERT_TRUE(exp.equalsTo(&x)); ASSERT_TRUE(exp.equalsTo(&x));
@ -588,10 +588,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, fillAsTriangular_test3) { TEST_F(NDArrayTest2, fillAsTriangular_test3) {
auto x = NDArrayFactory::create<float>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto x = NDArrayFactory::create<double>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16});
x.fillAsTriangular<float>(0., 0, 0, 'l'); x.fillAsTriangular<double>(0., 0, 0, 'l');
ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.isSameShape(&x));
ASSERT_TRUE(exp.equalsTo(&x)); ASSERT_TRUE(exp.equalsTo(&x));
@ -601,10 +601,10 @@ TEST_F(NDArrayTest2, fillAsTriangular_test3) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, fillAsTriangular_test4) { TEST_F(NDArrayTest2, fillAsTriangular_test4) {
auto x = NDArrayFactory::create<float>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); auto x = NDArrayFactory::create<double>('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16});
auto exp = NDArrayFactory::create<float>('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); auto exp = NDArrayFactory::create<double>('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0});
x.fillAsTriangular<float>(0., 1, 0, 'l'); x.fillAsTriangular<double>(0., 1, 0, 'l');
ASSERT_TRUE(exp.isSameShape(&x)); ASSERT_TRUE(exp.isSameShape(&x));
ASSERT_TRUE(exp.equalsTo(&x)); ASSERT_TRUE(exp.equalsTo(&x));
@ -612,11 +612,11 @@ TEST_F(NDArrayTest2, fillAsTriangular_test4) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, Test_DType_Conversion_1) { TEST_F(NDArrayTest2, Test_DType_Conversion_1) {
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6}); auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
auto xd = x.template asT<double>(); auto xd = x.template asT<double>();
auto xf = xd->template asT<float>(); auto xf = xd->template asT<double>();
ASSERT_TRUE(x.isSameShape(xf)); ASSERT_TRUE(x.isSameShape(xf));
ASSERT_TRUE(x.equalsTo(xf)); ASSERT_TRUE(x.equalsTo(xf));
@ -766,8 +766,8 @@ TEST_F(NDArrayTest2, Test_Linspace_5) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) {
auto x = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); auto x = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 3, 4}); auto exp = NDArrayFactory::create<double>('c', {4}, {1, 2, 3, 4});
auto set = x.allTensorsAlongDimension({0}); auto set = x.allTensorsAlongDimension({0});
// set->at(0)->printShapeInfo(); // set->at(0)->printShapeInfo();
@ -836,8 +836,8 @@ TEST_F(NDArrayTest2, scalar_set_test2) {
} }
TEST_F(NDArrayTest2, big_dup_test) { TEST_F(NDArrayTest2, big_dup_test) {
// auto arr = NDArrayFactory::linspace<float>(1.0f, 10000000.0f, 100000000); // auto arr = NDArrayFactory::linspace<double>(1.0f, 10000000.0f, 100000000);
auto arr = NDArrayFactory::linspace<float>(1.0f, 1000.0f, 10000); auto arr = NDArrayFactory::linspace<double>(1.0f, 1000.0f, 10000);
auto dup = arr->dup('c'); auto dup = arr->dup('c');
ASSERT_EQ(*arr, *dup); ASSERT_EQ(*arr, *dup);

View File

@ -682,7 +682,7 @@ TEST_F(NativeOpsTests, ScalarTest_1) {
TEST_F(NativeOpsTests, ScalarTest_2) { TEST_F(NativeOpsTests, ScalarTest_2) {
auto x = NDArrayFactory::create<float>('c', {5, 5}); auto x = NDArrayFactory::create<float>('c', {5, 5});
auto y = NDArrayFactory::create<float>(10.); auto y = NDArrayFactory::create<float>(10.f);
auto exp = NDArrayFactory::create<bool>('c', {5,5}); auto exp = NDArrayFactory::create<bool>('c', {5,5});
auto z = NDArrayFactory::create<bool>('c', {5,5}); auto z = NDArrayFactory::create<bool>('c', {5,5});
@ -714,9 +714,9 @@ TEST_F(NativeOpsTests, ScalarTest_2) {
} }
TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); auto x = NDArrayFactory::create<float>('c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f});
auto exp = NDArrayFactory::create<float>(0.9); auto exp = NDArrayFactory::create<float>(0.9f);
auto z = NDArrayFactory::create<float>(0.21587136); auto z = NDArrayFactory::create<float>(0.21587136f);
Nd4jPointer extra[6]; Nd4jPointer extra[6];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -739,9 +739,9 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) {
} }
TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) { TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); auto x = NDArrayFactory::create<double>('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
auto exp = NDArrayFactory::create<float>(0.9); auto exp = NDArrayFactory::create<double>(0.9);
auto z = NDArrayFactory::create<float>(0.21587136); auto z = NDArrayFactory::create<double>(0.21587136);
Nd4jPointer extra[6]; Nd4jPointer extra[6];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -764,9 +764,9 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_2) {
} }
TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); auto x = NDArrayFactory::create<double>('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5});
auto exp = NDArrayFactory::create<float>(0.9); auto exp = NDArrayFactory::create<double>(0.9);
auto z = NDArrayFactory::create<float>(0.21587136); auto z = NDArrayFactory::create<double>(0.21587136);
Nd4jPointer extra[6]; Nd4jPointer extra[6];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -794,9 +794,9 @@ TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) {
} }
TEST_F(NativeOpsTests, TransformTest_1) { TEST_F(NativeOpsTests, TransformTest_1) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); auto x = NDArrayFactory::create<double>('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625});
auto exp = NDArrayFactory::create<float>('c', {5, 5}); auto exp = NDArrayFactory::create<double>('c', {5, 5});
auto z = NDArrayFactory::create<float>('c', {5,5}); auto z = NDArrayFactory::create<double>('c', {5,5});
Nd4jPointer extra[6]; Nd4jPointer extra[6];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -821,7 +821,7 @@ TEST_F(NativeOpsTests, TransformTest_1) {
} }
TEST_F(NativeOpsTests, TransformTest_2) { TEST_F(NativeOpsTests, TransformTest_2) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); auto x = NDArrayFactory::create<float>('c', {5, 5}, {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f});
auto exp = NDArrayFactory::create<float>('c', {5, 5}); auto exp = NDArrayFactory::create<float>('c', {5, 5});
auto z = NDArrayFactory::create<float>('c', {5,5}); auto z = NDArrayFactory::create<float>('c', {5,5});
@ -878,10 +878,10 @@ TEST_F(NativeOpsTests, TransformTest_3) {
} }
TEST_F(NativeOpsTests, TransformTest_4) { TEST_F(NativeOpsTests, TransformTest_4) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, auto x = NDArrayFactory::create<double>('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592,
3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); 3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0});
auto exp = NDArrayFactory::create<float>('c', {5, 5}); auto exp = NDArrayFactory::create<double>('c', {5, 5});
auto z = NDArrayFactory::create<float>('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0, auto z = NDArrayFactory::create<double>('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0,
0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0, 0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0,
0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.}); 0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.});
@ -909,7 +909,7 @@ TEST_F(NativeOpsTests, TransformTest_4) {
TEST_F(NativeOpsTests, ScalarTadTest_1) { TEST_F(NativeOpsTests, ScalarTadTest_1) {
auto x = NDArrayFactory::create<float>('c', {5, 5}); auto x = NDArrayFactory::create<float>('c', {5, 5});
auto y = NDArrayFactory::create<float>(10.); auto y = NDArrayFactory::create<float>(10.f);
auto exp = NDArrayFactory::create<float>('c', {5,5}); auto exp = NDArrayFactory::create<float>('c', {5,5});
auto z = NDArrayFactory::create<float>('c', {5,5}); auto z = NDArrayFactory::create<float>('c', {5,5});
@ -1433,9 +1433,9 @@ TEST_F(NativeOpsTests, MapTests_1) {
} }
TEST_F(NativeOpsTests, CustomOpTest_1) { TEST_F(NativeOpsTests, CustomOpTest_1) {
auto x = NDArrayFactory::create<float>('c', {1, 6}, {1, 2, 3, 4, 5, 6}); auto x = NDArrayFactory::create<float>('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto z = NDArrayFactory::create<float>('c', {6}); auto z = NDArrayFactory::create<float>('c', {6});
auto e = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6}); auto e = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
nd4j::ops::squeeze op; nd4j::ops::squeeze op;

View File

@ -23,11 +23,11 @@
class EqualsTest : public testing::Test { class EqualsTest : public testing::Test {
public: public:
Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102}; Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102};
float data[2] = {1.0,7.0}; float data[2] = {1.0f, 7.0f};
Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99}; Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99};
float dataSecond[12] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0}; float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f};
int opNum = 4; int opNum = 4;
float extraArgs[1] = {1e-6}; float extraArgs[1] = {1e-6f};
int dimension[1] = {2147483647}; int dimension[1] = {2147483647};
int dimensionLength = 1; int dimensionLength = 1;
}; };

File diff suppressed because one or more lines are too long