/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 09.02.18. // #include "testlayers.h" #include #include #include #include #include using namespace nd4j; using namespace nd4j::graph; class DeclarableOpsTests7 : public testing::Test { public: DeclarableOpsTests7() { printf("\n"); fflush(stdout); } }; template class TypedDeclarableOpsTests7 : public testing::Test { public: TypedDeclarableOpsTests7() { printf("\n"); fflush(stdout); } }; typedef ::testing::Types TestingTypes; TYPED_TEST_CASE(TypedDeclarableOpsTests7, TestingTypes); TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { double inputData[150] = { 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 }; auto x = NDArrayFactory::create(inputData,'c',{1,149}); nd4j::ops::choose op; //greater than test auto result = op.execute({&x}, {0.0},{3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(1); z->printIndexedBuffer("CHOOSE test"); ASSERT_EQ(148,z->e(0)); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) { std::vector data; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); } auto x = NDArrayFactory::create('c',{1,4},data); nd4j::ops::choose op; //greater than test auto result = op.execute({&x}, {0.0},{3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(1); auto array = *z; ASSERT_EQ(3,array.e(0)); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { std::vector data; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); } auto x = NDArrayFactory::create('c',{1,4},data); auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); nd4j::ops::choose op; //greater than test auto result = op.execute({&x,&scalar}, {1.0},{3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(3, z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { std::vector data; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); } auto x = NDArrayFactory::create('c',{1,4},data); auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); nd4j::ops::choose op; //greater than test auto result = op.execute({&scalar,&x}, {1.0},{3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(3,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { std::vector data; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); } auto x = NDArrayFactory::create('c',{1,4},data); nd4j::ops::choose op; //greater than test auto result = op.execute({&x}, {1.0},{3}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(2,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { std::vector data; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); } auto x = NDArrayFactory::create('c',{1,4},data); nd4j::ops::choose op; //greater than test auto result = op.execute({&x}, {1.0},{5}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); ASSERT_EQ(3,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, TEST_WHERE) { std::vector data; std::vector mask; std::vector put; std::vector resultData; std::vector assertion; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); if(i > 1) { assertion.push_back(5.0); mask.push_back(true); } else { assertion.push_back(i); mask.push_back(false); } put.push_back(5.0); resultData.push_back(0.0); } auto x = NDArrayFactory::create('c',{1,4},data); auto maskArr = NDArrayFactory::create('c',{1,4},mask); auto putArr = NDArrayFactory::create('c',{1,4},put); auto resultArr = NDArrayFactory::create('c',{1,4},resultData); nd4j::ops::where_np op; //greater than test // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, false); ASSERT_EQ(Status::OK(), result); for(int i = 0; i < 4; i++) ASSERT_EQ(assertion[i],resultArr.e(i)); // auto z = result->at(0); //ASSERT_EQ(4,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_MASK) { double x[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; double z[300] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; bool mask[300] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; double put[200] = {0.99666107,0.9867112,0.97686064,0.9671082,0.95745337,0.9478948,0.9384318,0.92906314,0.9197881,0.91060543,0.9015147,0.8925147,0.8836044,0.8747831,0.86605,0.85740393,0.8488442,0.84037,0.83198035,0.8236745,0.8154515,0.8073106,0.79925096,0.79127187,0.7833724,0.77555174,0.76780915,0.7601439,0.75255525,0.7450422,0.7376043,0.73024046,0.72295034,0.715733,0.7085876,0.7015135,0.69451016,0.68757665,0.6807124,0.6739167,0.66718876,0.66052806,0.6539338,0.6474054,0.6409421,0.6345435,0.6282087,0.6219371,0.6157281,0.60958105,0.6034956,0.59747064,0.5915059,0.5856007,0.57975453,0.5739667,0.5682366,0.5625637,0.5569475,0.5513874,0.54588276,0.540433,0.53503764,0.5296962,0.52440816,0.51917285,0.5139898,0.5088585,0.50377846,0.4987491,0.4937699,0.48884052,0.48396033,0.47912875,0.47434545,0.4696099,0.46492168,0.46028027,0.45568514,0.4511359,0.44663212,0.4421733,0.43775895,0.43338865,0.42906195,0.42477852,0.4205379,0.41633952,0.41218308,0.40806815,0.40399432,0.3999611,0.3959682,0.39201516,0.38810158,0.384227,0.38039115,0.37659356,0.37283397,0.3691119,0.36542687,0.36177874,0.35816705,0.3545914,0.35105142,0.34754673,0.34407702,0.34064204,0.33724132,0.3338745,0.33054137,0.3272415,0.32397458,0.32074028,0.3175382,0.31436813,0.31122974,0.3081226,0.30504647,0.30200112,0.2989862,0.29600134,0.29304633,0.2901207,0.28722438,0.28435695,0.2815181,0.27870762,0.27592525,0.27317056,0.27044344,0.26774356,0.26507056,0.2624243,0.25980446,0.25721073,0.25464293,0.25210077,0.249584,0.24709237,0.24462552,0.24218333,0.23976555,0.23737194,0.23500215,0.23265606,0.23033342,0.22803394,0.22575743,0.2235036,0.22127232,0.21906327,0.21687631,0.21471114,0.21256764,0.21044552,0.20834461,0.20626466,0.20420544,0.20216681,0.20014854,0.19815037,0.19617215,0.19421372,0.19227484,0.19035533,0.18845497,0.18657354,0.18471093,0.18286693,0.18104129,0.17923392,0.17744459,0.17567308,0.1739193,0.17218304,0.17046405,0.16876228,0.16707748,0.16540948,0.16375816,0.16212334,0.16050482,0.15890247,0.15731607,0.15574552,0.15419069,0.15265137,0.15112738,0.14961864,0.14812498,0.14664622,0.1451822,0.14373279,0.14229788,0.14087726,0.13947085,0.13807845,0.13669999,0.13533528}; double assertion[300] = {1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,9.966611049434810354e-01,9.867111603284486332e-01,9.768605487739230320e-01,9.671082786103732953e-01,9.574533680683808834e-01,9.478948451798039354e-01,9.384317476799283186e-01,9.290631229105962285e-01,9.197880277243004610e-01,9.106055283892373620e-01,9.015147004953073528e-01,8.925146288610534828e-01,8.836044074415293492e-01,8.747831392370875037e-01,8.660499362030764647e-01,8.574039191604412302e-01,8.488442177072155204e-01,8.403699701308978698e-01,8.319803233217017979e-01,8.236744326866727306e-01,8.154514620646623468e-01,8.073105836421510251e-01,7.992509778699116163e-01,7.912718333805045523e-01,7.833723469065965173e-01,7.755517232000953554e-01,7.678091749520912224e-01,7.601439227135980969e-01,7.525551948170853267e-01,7.450422272987937689e-01,7.376042638218265335e-01,7.302405556000080011e-01,7.229503613225031211e-01,7.157329470791886639e-01,7.085875862867698771e-01,7.015135596156351072e-01,6.945101549174396149e-01,6.875766671534137009e-01,6.807123983233853703e-01,6.739166573955123196e-01,6.671887602367149173e-01,6.605280295438040739e-01,6.539337947752965619e-01,6.474053920839111242e-01,6.409421642497381555e-01,6.345434606140767375e-01,6.282086370139332576e-01,6.219370557171712832e-01,6.157280853583116942e-01,6.095811008749726367e-01,6.034954834449430816e-01,5.974706204238864338e-01,5.915059052836644238e-01,5.856007375512777280e-01,5.797545227484157682e-01,5.739666723316099173e-01,5.682366036329845604e-01,5.625637398015992385e-01,5.569475097453767676e-01,5.513873480736106725e-01,5.458826950400470501e-01,5.404329964865340896e-01,5.350377037872348085e-01,5.296962737933965659e-01,5.244081687786711354e-01,5.191728563849821176e-01,5.139898095689314772e-01,5.088585065487419845e-01,5.037784307517284565e-01,4.987490707622945774e-01,4.937699202704479151e-01,4.888404780208293054e-01,4.839602477622509946e-01,4.791287381977387683e-01,4.743454629350723484e-01,4.696099404378203390e-01,4.649216939768630041e-01,4.602802515824001017e-01,4.556851459964368911e-01,4.511359146257447605e-01,4.466320994952920342e-01,4.421732472021388527e-01,4.377589088697927955e-01,4.333886401030203062e-01,4.290620009431086457e-01,4.247785558235752101e-01,4.205378735263185508e-01,4.163395271382073215e-01,4.121830940081024908e-01,4.080681557043087104e-01,4.039942979724505667e-01,3.999611106937689398e-01,3.959681878438343627e-01,3.920151274516718853e-01,3.881015315592946102e-01,3.842270061816405180e-01,3.803911612669100828e-01,3.765936106572991271e-01,3.728339720501240850e-01,3.691118669593352886e-01,3.654269206774144463e-01,3.617787622376523182e-01,3.581670243768036999e-01,3.545913434981138868e-01,3.510513596347161203e-01,3.475467164133922426e-01,3.440770610186974499e-01,3.406420441574410929e-01,3.372413200235238606e-01,3.338745462631242389e-01,3.305413839402346898e-01,3.272414975025391692e-01,3.239745547476344245e-01,3.207402267895853032e-01,3.175381880258169032e-01,3.143681161043347383e-01,3.112296918912743071e-01,3.081225994387726264e-01,3.050465259531625062e-01,3.020011617634821843e-01,2.989862002903017069e-01,2.960013380148582840e-01,2.930462744485015647e-01,2.901207121024425017e-01,2.872243564578055852e-01,2.843569159359789489e-01,2.815181018692606840e-01,2.787076284717992514e-01,2.759252128108221624e-01,2.731705747781537075e-01,2.704434370620155681e-01,2.677435251191103149e-01,2.650705671469821278e-01,2.624242940566549609e-01,2.598044394455423789e-01,2.572107395706292876e-01,2.546429333219200064e-01,2.521007621961529055e-01,2.495839702707757235e-01,2.470923041781825646e-01,2.446255130802063582e-01,2.421833486428674187e-01,2.397655650113727777e-01,2.373719187853666479e-01,2.350021689944260528e-01,2.326560770738031469e-01,2.303334068404078172e-01,2.280339244690317291e-01,2.257573984688081292e-01,2.235035996599082919e-01,2.212723011504689752e-01,2.190632783137518302e-01,2.168763087655291855e-01,2.147111723416972873e-01,2.125676510761114746e-01,2.104455291786438698e-01,2.083445930134591173e-01,2.062646310775079761e-01,2.042054339792348794e-01,2.021667944174980747e-01,2.001485071607009836e-01,1.981503690261307848e-01,1.961721788595043592e-01,1.942137375147174327e-01,1.922748478337968081e-01,1.903553146270518526e-01,1.884549446534251604e-01,1.865735466010380594e-01,1.847109310679319050e-01,1.828669105430000552e-01,1.810412993871116094e-01,1.792339138144224131e-01,1.774445718738737465e-01,1.756730934308744496e-01,1.739193001491673995e-01,1.721830154728755669e-01,1.704640646087285105e-01,1.687622745084652875e-01,1.670774738514141378e-01,1.654094930272448083e-01,1.637581641188943782e-01,1.621233208856623365e-01,1.605047987464754966e-01,1.589024347633189727e-01,1.573160676248336609e-01,1.557455376300762306e-01,1.541906866724424563e-01,1.526513582237501165e-01,1.511273973184814046e-01,1.496186505381822129e-01,1.481249659960175158e-01,1.466461933214808777e-01,1.451821836452561187e-01,1.437327895842310799e-01,1.422978652266598532e-01,1.408772661174743090e-01,1.394708492437411185e-01,1.380784730202649913e-01,1.366999972753347725e-01,1.353352832366127023e-01}; Nd4jLong threeHundredShapePointer[8] = {2,1,300,1,1,0,1,99}; Nd4jLong twoHundredShapePointer[8] = {2,1,200,1,1,0,1,99}; nd4j::ops::where_np op; ArrayOptions::setDataType(threeHundredShapePointer, nd4j::DataType::DOUBLE); ArrayOptions::setDataType(twoHundredShapePointer, nd4j::DataType::DOUBLE); NDArray xArr(x,threeHundredShapePointer); NDArray putArr(put,twoHundredShapePointer); NDArray resultArr(z,threeHundredShapePointer); resultArr.assign(0.0); ArrayOptions::setDataType(threeHundredShapePointer, nd4j::DataType::BOOL); NDArray maskArr(mask,threeHundredShapePointer); ArrayOptions::setDataType(threeHundredShapePointer, nd4j::DataType::DOUBLE); NDArray assertArr(assertion, threeHundredShapePointer); Nd4jStatus result = op.execute({&maskArr, &xArr, &putArr},{&resultArr},{},{},{}); ASSERT_EQ(Status::OK(),result); ASSERT_TRUE(assertArr.isSameShape(resultArr)); ASSERT_TRUE (assertArr.equalsTo(resultArr)); } TEST_F(DeclarableOpsTests7, TEST_WHERE_SCALAR) { std::vector data; std::vector mask; std::vector put; std::vector resultData; std::vector assertion; for(Nd4jLong i = 0; i < 4; i++) { data.push_back(i); if(i > 1) { assertion.push_back(5.0); mask.push_back(true); } else { assertion.push_back(i); mask.push_back(false); } resultData.push_back(0.0); } put.push_back(5.0); auto x = NDArrayFactory::create('c',{1,4},data); auto maskArr = NDArrayFactory::create('c',{1,4},mask); auto putArr = NDArrayFactory::create('c',{1,1},put); auto resultArr = NDArrayFactory::create('c',{1,4},resultData); nd4j::ops::where_np op; //greater than test // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, false); // ASSERT_EQ(Status::OK(), result->status()); for(int i = 0; i < 4; i++) ASSERT_EQ(assertion[i],resultArr.e(i)); // auto z = result->at(0); //ASSERT_EQ(4,z->lengthOf()); //ASSERT_TRUE(exp.isSameShape(z)); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_1) { auto x = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); auto z = NDArrayFactory::create('c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); nd4j::ops::matrix_diag_part op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); nd4j::ops::matrix_diag_part op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_1) { auto z = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); auto x = NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); nd4j::ops::matrix_diag op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestMatrixDiag_2) { auto z = NDArrayFactory::create('c', {2, 3, 3}, {1., 0., 0., 0., 2., 0., 0., 0., 3.,5., 0., 0., 0., 6., 0.,0., 0., 7.}); auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); nd4j::ops::matrix_diag op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_1) { auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto shape = NDArrayFactory::create({1, 2, 3}); nd4j::ops::random_crop op; auto result = op.execute({&x, &shape}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); // ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRandomCrop_2) { auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto shape = NDArrayFactory::create({2, 2, 2}); nd4j::ops::random_crop op; auto result = op.execute({&x, &shape}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); // ASSERT_TRUE(z.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119) { auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); nd4j::ops::dynamic_stitch op; auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); // result->at(0)->printIndexedBuffer("Output"); // exp.printIndexedBuffer("Expect"); // result->at(0)->printShapeInfo("Output shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) { auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); nd4j::ops::dynamic_stitch op; auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); // result->at(0)->printIndexedBuffer("Output"); // exp.printIndexedBuffer("Expect"); // result->at(0)->printShapeInfo("Output shape"); auto res = result->at(0); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); int numOfCases = 100; auto timeStart = std::chrono::system_clock::now(); for (int i = 0; i < numOfCases; i++) { op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {res}, {}, {}, {}); } auto timeEnd = std::chrono::system_clock::now(); auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); //nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld us.\n", numOfCases, outerTime / numOfCases); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) { auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); auto data0 = NDArrayFactory::create('c', {2,5,4}); auto data1 = NDArrayFactory::create('c', {2,3,5,4}); auto data2 = NDArrayFactory::create('c', {3,1,5,4}); auto exp = NDArrayFactory::create('c', {11, 5, 4}, { 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, }); data0.linspace(1); data1.linspace(21); data2.linspace(141); nd4j::ops::dynamic_stitch op; auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); z->printIndexedBuffer("Stitch"); z->printShapeInfo("Stitch Shape"); ASSERT_TRUE(z->isSameShape(exp)); ASSERT_TRUE(z->equalsTo(exp)); delete result; } TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) { auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); auto data0 = NDArrayFactory::create('c', {2,5,4}); auto data1 = NDArrayFactory::create('c', {2,3,5,4}); auto data2 = NDArrayFactory::create('c', {3,1,5,4}); auto exp = NDArrayFactory::create('c', {11, 5, 4}, { 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, }); data0.linspace(1); data1.linspace(41); data2.linspace(161); nd4j::ops::dynamic_stitch op; auto result = op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); z->printIndexedBuffer("Stitch"); z->printShapeInfo("Stitch Shape"); ASSERT_TRUE(z->isSameShape(exp)); ASSERT_TRUE(z->equalsTo(exp)); delete result; } TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { auto x = NDArrayFactory::create('c', {5, 4, 11}); auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); auto e = NDArrayFactory::create('c', {5, 11}); x.assign(1.f); e.assign(1.f); nd4j::ops::dynamic_partition op; auto result = op.execute({&x, &y}, {}, {4}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(4, result->size()); auto z = result->at(0); // z->printShapeInfo("Output shape info"); // z->printIndexedBuffer("Output1"); // result->at(1)->printIndexedBuffer("Output2"); // result->at(2)->printIndexedBuffer("Output3"); // result->at(3)->printIndexedBuffer("Output4"); ASSERT_TRUE(e.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20,11, 21,12, 22,13, 23,14, 24,15, 25,16, 26,17, 27,18, 28,19, 29,20, 30,21, 31}); auto y = NDArrayFactory::create('c', {3, 4}, {0,0,0,0, 2,2,2,2, 2,1,1,1}); auto e = NDArrayFactory::create('c', {4, 2}, {10, 20, 11, 21, 12, 22, 13, 23}); // x.assign(1.f); // e.assign(1.f); nd4j::ops::dynamic_partition op; auto result = op.execute({&x, &y}, {}, {3}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(3, result->size()); auto z = result->at(0); // z->printShapeInfo("Output shape info"); // result->at(1)->printShapeInfo("Shape2"); // result->at(2)->printShapeInfo("Shape3"); // result->at(3)->printShapeInfo("Shape4"); // z->printIndexedBuffer("Output1"); // result->at(1)->printIndexedBuffer("Output2"); // result->at(2)->printIndexedBuffer("Output3"); // result->at(3)->printIndexedBuffer("Output4"); ASSERT_TRUE(e.isSameShape(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { auto x = NDArrayFactory::create('c', {5, 4, 11}); auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); auto e1 = NDArrayFactory::create('c', {5, 11}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); auto e2 = NDArrayFactory::create('c', {5, 11}, { 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); auto e3 = NDArrayFactory::create('c', {5, 11}, {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); auto e4 = NDArrayFactory::create('c', {5, 11}, { 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}) ; std::vector e({&e1, &e2, &e3, &e4}); x.linspace(1.f); //.assign(1.f); nd4j::ops::dynamic_partition op; auto result = op.execute({&x, &y}, {}, {4}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(4, result->size()); for (size_t i = 0; i < result->size(); i++) { auto z = result->at(i); // z->printShapeInfo("Output shape info"); // z->printIndexedBuffer("Output1"); // result->at(1)->printIndexedBuffer("Output2"); // result->at(2)->printIndexedBuffer("Output3"); // result->at(3)->printIndexedBuffer("Output4"); ASSERT_TRUE(e[i]->isSameShape(z)); ASSERT_TRUE(e[i]->equalsTo(z)); } delete result; } TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); nd4j::ops::sequence_mask op; auto result = op.execute({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Output"); // z->printShapeInfo("Shape"); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; } TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); nd4j::ops::sequence_mask op; auto result = op.execute({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printBuffer("Output"); // z->printShapeInfo("Shape"); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); nd4j::ops::segment_max op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printBuffer("MaX1"); // exp.printBuffer("ExP1"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_01) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1., 10, 40, 30}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5,5, 5}); auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); nd4j::ops::segment_max op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printBuffer("MaX01"); // exp.printBuffer("ExP01"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("OutputMaxBP"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_2) { auto x = NDArrayFactory::create('c', {5, 4}, { 0, 1.8, 2.5, 4., 1, 9., 2.1, 2.4, 0, 3., 9., 2.1, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {1, 9, 9, 4, 2, 1, 2.1, 0.7, 3, 4.2, 2.2, 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_max op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); auto out = result->at(0); // out->printIndexedBuffer("Output2Max"); // exp.printIndexedBuffer("Expect2Max"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); // NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto exp = NDArrayFactory::create('c', {4, 4}, {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_max_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); //exp.printIndexedBuffer("BP Max Expect"); //result->at(0)->printIndexedBuffer("BP Max Output"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28.,119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117.,51. , 42. , 87. , 44., 55.1, 56.4, 93. , 28.,119.1, 82.1,112.7,113.1,114. ,114.2,116.2,117.,91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28., 119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117. }); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_max op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output3Max"); // result->at(0)->printShapeInfo("Out Shape 3 Max"); // exp.printIndexedBuffer("Expect3Max"); // exp.printShapeInfo("Exp Shape 3 Max"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24., 15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_max op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); nd4j::ops::unsorted_segment_max op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_2) { auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({3., 0., 1., 0., 2., 0., 0., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); auto eps = NDArrayFactory::create('c', {5}); nd4j::ops::segment_max_bp op; eps.linspace(1); auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_2) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); auto exp = NDArrayFactory::create({2.2, 9., -DataTypeUtils::max(), 9., 4.2}); nd4j::ops::unsorted_segment_max op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("OutputUnsortedMax"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_3) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::unsorted_segment_max op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); //exp.printIndexedBuffer("Expect"); //result->at(0)->printIndexedBuffer("Output"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_4) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 0, 2}); double principalMax = DataTypeUtils::max(); auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 11.7, 9, -principalMax, -principalMax, -principalMax, -principalMax, 3., 4.2, 2.2, 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::unsorted_segment_max op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); //exp.printIndexedBuffer("Expect"); //result->at(0)->printIndexedBuffer("Output"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4, 3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); out->printIndexedBuffer("Segment mIN1"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_01) { auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); out->printIndexedBuffer("Segment mIN01"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); out->printIndexedBuffer("Segment mIN02"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); auto eps = NDArrayFactory::create('c', {5}); eps.linspace(1); nd4j::ops::segment_min_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output1"); //exp.printIndexedBuffer("Expecte"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); auto eps = NDArrayFactory::create('c', {5}); eps.linspace(1); nd4j::ops::unsorted_segment_min_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output1"); //exp.printIndexedBuffer("Expecte"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_2) { auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); auto eps = NDArrayFactory::create('c', {5}); eps.linspace(1); nd4j::ops::unsorted_segment_min_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output1"); //exp.printIndexedBuffer("Expecte"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMinBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3. , 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto exp = NDArrayFactory::create('c', {4, 4}, {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_min_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); // result->at(0)->printIndexedBuffer("Output"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMin_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_min op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); nd4j::ops::unsorted_segment_min op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_01) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); nd4j::ops::unsorted_segment_min op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::unsorted_segment_min op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_min op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); double principalMax = DataTypeUtils::max(); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 51., 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, 91., 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_min op; auto result = op.execute({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); nd4j::ops::segment_mean op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); nd4j::ops::segment_mean op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // result->at(0)->printIndexedBuffer("Output"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto eps = NDArrayFactory::create('c', {3, 4}); auto exp = NDArrayFactory::create('c', {4, 4}, { 0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); eps.linspace(1); nd4j::ops::segment_mean_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); // result->at(0)->printIndexedBuffer("Output"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); nd4j::ops::segment_mean op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMean_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_mean op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); nd4j::ops::unsorted_segment_mean op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); nd4j::ops::segment_mean_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); nd4j::ops::unsorted_segment_mean_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_2) { auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({3., 1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 4./3., 4./3., 4./3., 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); nd4j::ops::unsorted_segment_mean_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1.}); nd4j::ops::unsorted_segment_mean op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // result->at(0)->printIndexedBuffer("Output"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); nd4j::ops::unsorted_segment_mean op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_mean op; auto result = op.execute({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({3.0405593, 8.75, 3., 7.621024, 4.5723805}); nd4j::ops::unsorted_segment_sqrt_n op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_BP_1) { auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); // NDArray exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); auto exp = NDArrayFactory::create({3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); nd4j::ops::unsorted_segment_sqrt_n_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Hello Out:"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, { 2.7577164, 3.4648232, 4.9497476, 12.727922, 2.1, 2.1, 0.7, 0.1, 3. , 4.2, 2.2, 1. }); nd4j::ops::unsorted_segment_sqrt_n op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // result->at(0)->printIndexedBuffer("Output"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , 57.982758, 45.254833, 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, 93.62093, 90.50967, 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); nd4j::ops::unsorted_segment_sqrt_n op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_sqrt_n op; auto result = op.execute({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { auto x = NDArrayFactory::create({1.,2.,5.,7.,3.,1.,3.,4.}); auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); nd4j::ops::unsorted_segment_sqrt_n op; auto result = op.execute({&x, &idx}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); nd4j::ops::segment_sum op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("Output Sum"); exp.printIndexedBuffer("Expect Sum"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::segment_sum_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::unsorted_segment_sum_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_2) { auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({ 3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); nd4j::ops::unsorted_segment_sum_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); nd4j::ops::segment_sum op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSumBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {4, 4}, {1. , 2., 3., 4., 1. , 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto eps = NDArrayFactory::create('c', {3, 4}); eps.linspace(1); nd4j::ops::segment_sum_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_sum op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentSum_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::segment_sum op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); nd4j::ops::unsorted_segment_sum op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("UnsortedSum1"); exp.printIndexedBuffer("Unsorted Sum1 Exp"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); nd4j::ops::unsorted_segment_sum op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_sum op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 3, 7}); auto exp = NDArrayFactory::create('c', {8, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_sum op; auto result = op.execute({&x, &idx}, {}, {8}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_1) { auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_1) { auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); nd4j::ops::segment_prod_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("ProdBP Output"); // exp.printIndexedBuffer("ProdBP Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_1) { auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); nd4j::ops::segment_prod_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("ProdBP Output"); //exp.printIndexedBuffer("ProdBP Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_2) { auto x = NDArrayFactory::create({ 3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); auto n = NDArrayFactory::create(5LL); nd4j::ops::unsorted_segment_prod_bp op; auto result = op.execute({&x, &idx, &eps, &n}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Unsorted ProdBP Output"); //exp.printIndexedBuffer("Unsorted ProdBP Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_2) { auto x = NDArrayFactory::create('c', {4, 4}, { 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProdBP_2) { auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto eps = NDArrayFactory::create('c', {3, 4}); auto exp = NDArrayFactory::create('c', {4, 4}, {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., 6., 7., 8., 9., 10., 11., 12.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} eps.linspace(1); nd4j::ops::segment_prod_bp op; auto result = op.execute({&x, &idx, &eps}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 2); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, 1882.4401, 1287, 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_04) { auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_05) { auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_06) { auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); x.printIndexedBuffer("INPUT INT8"); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentProd_07) { auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); x.printIndexedBuffer("INPUT INT8"); nd4j::ops::segment_prod op; auto result = op.execute({&x, &idx}, {}, {}); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_11) { auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {5}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_2) { auto x = NDArrayFactory::create('c', {4, 4}, { 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); auto idx = NDArrayFactory::create({0, 0, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) { auto x = NDArrayFactory::create('c', {4, 4}, { 3., 4.2, 2.2, 1., 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1 }); auto idx = NDArrayFactory::create({2, 0, 0, 1}); auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_EQ(result->size(), 1); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, 1882.4401, 1287, 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) { auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({1, 1, 1, 2}); auto exp = NDArrayFactory::create('c', {3, 4, 4}, { 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); nd4j::ops::unsorted_segment_prod op; auto result = op.execute({&x, &idx}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); // result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { auto x = NDArrayFactory::create('c', {8}, { 5,1,7,2,3,4,1,3}); auto gradO = NDArrayFactory::create('c', {4}, {1,2,3,4}); // ---------------------------------------------------------------- auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); auto exp = NDArrayFactory::create('c', {8}, { 7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, 4.000000 }); // 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., // // 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, // 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, // // 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); nd4j::ops::unsorted_segment_prod_bp op; auto result = op.execute({&x, &idx, &gradO}, {}, {4}); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); //exp.printIndexedBuffer("Expect"); // exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_1) { auto x = NDArrayFactory::create('c', {2,4, 4, 4}, { 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); // ---------------------------------------------------------------- auto exp = NDArrayFactory::create('c', {2, 4, 4, 4}, { 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {1,1,1,1,1,1,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_2) { auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); //Images shape is (3, 3, 4, 3) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {3, 1, 1, 12}, { 11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24. }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 3,3, 1,1,0}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_3) { auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); //Images shape is (3, 3, 4, 3) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {3, 1, 2, 6}, { 11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., 9., 8., 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., 211., 12., 13., 25., 6., 7., 15., 216., 17., 35., 36., 327. }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,1,3,2,2,2,0}); ASSERT_EQ(result->status(), Status::OK()); // x.printIndexedBuffer("images"); // nd4j_printf("input params: ksize = [1, 2, 1, 1], strides = [1, 3, 2, 1], rates = [1, 2, 2, 1]\n", ""); result->at(0)->printBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); exp.printBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_4) { auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); //Images shape is (3, 3, 4, 3) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {3, 3, 4, 3}, { 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {1,1,1,1,1,1,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_5) { auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); //Images shape is (3, 3, 4, 3) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {3, 1, 1, 18}, { 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., 23., 5., 6., 7., 35., 36., 37., 9., 8., 7., 49., 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., 135., 136., 137., 211., 12., 13., 15., 216., 17., 21., 2., 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. //Patch shape is (3, 1, 2, 18) }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {3,2,3,2,1,2,0}); ASSERT_EQ(result->status(), Status::OK()); // result->at(0)->printIndexedBuffer("Output"); //result->at(0)->printShapeInfo("Out Shape"); // exp.printIndexedBuffer("Expect"); //exp.printShapeInfo("Exp Shape"); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_6) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); //Images shape is (3, 3, 4, 3) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {2, 1, 4, 4}, { 11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42 }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,1, 1,1, 1,1,0}); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.isSameShape(result->at(0))); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_7) { auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { 1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., 4., 5., 7., 8., 5., 6., 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0. }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); // exp.printBuffer("Expect"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_8) { auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { 1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, 0, 0, 11, 12, 0, 0, 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, 13, 14, 15, 16, 0, 0, 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0 }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); // exp.printBuffer("Expect"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9) { auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 6, 6, 18}, { 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 0., 0., 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., 0., 0., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 0., 0., 0., 0., 0., 0., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0.}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); // exp.printBuffer("ExpectSame"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9_1) { auto x = NDArrayFactory::create('c', {1, 4, 4, 2}, {1, 116, 2, 116, 3, 116, 4, 116, 5, 117, 6, 117, 7, 117, 8, 117, 9, 118, 10, 118, 11, 118, 12, 118, 13, 119, 14, 119, 15, 119, 16, 119}); //x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { 1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, 117, 7, 117, 3, 116, 4, 116, 7, 117, 8, 117, 4, 116, 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, 9, 118, 10, 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, 11, 118, 12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, 118, 10, 118, 13, 119, 14, 119, 10, 118, 11, 118, 14, 119, 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, 0, 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, 14, 119, 15, 119, 0, 0, 0, 0, 15, 119, 16, 119, 0, 0, 0, 0, 16, 119, 0, 0, 0, 0, 0, 0 }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); // exp.printBuffer("ExpectSame"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } // // //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_10) { auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 4, 4, 18}, { 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72.}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); auto result = op.execute({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); // exp.printBuffer("ExpectValid"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010) { auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); // exp.printBuffer("ExpectValid"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010_1) { auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 4, 4, 4}, { 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, 8, 0, 12, 0, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputSame"); // exp.printBuffer("ExpectSame"); // exp.printIndexedBuffer("Expect Same Formatted"); // output->printIndexedBuffer("Output Same Formatted"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_011) { auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, { 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16, }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; //x.printIndexedBuffer("Images"); //x.printBuffer("Images linear"); auto result = op.execute({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("OutputValid"); // exp.printBuffer("ExpectValid"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_11) { auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { 1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, 22, 23, 24, 9, 10, 11, 12, 25, 26, 27, 28, 13, 14, 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, 49, 50, 51, 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, 57, 58, 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, 69, 70, 71, 72, 85, 86, 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, 79, 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, 101, 102, 103, 104, 117, 118, 119, 120, 105, 106, 107, 108, 121, 122, 123, 124, 109, 110, 111, 112, 125, 126, 127, 128}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printBuffer("Output"); // exp.printBuffer("Expect"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_12) { auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); x.linspace(1); //Images shape is (1, 3, 3, 4) //[1, 1, 1, 1] //[1, 3, 2, 1] auto exp = NDArrayFactory::create('c', {1, 8, 8, 8}, { 0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, 21, 22, 0, 0, 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, 21, 22, 25, 26, 0, 0, 0, 0, 23, 24, 27, 28, 0, 0, 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, 31, 32, 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, 35, 36, 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, 35, 36, 39, 40, 5, 6, 9, 10, 37, 38, 41, 42, 7, 8, 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, 41, 42, 45, 46, 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, 49, 50, 53, 54, 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, 25, 26, 53, 54, 57, 58, 23, 24, 27, 28, 55, 56, 59, 60, 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, 59, 60, 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, 0, 0, 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, 39, 40, 67, 68, 71, 72, 37, 38, 41, 42, 69, 70, 73, 74, 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, 45, 46, 73, 74, 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, 53, 54, 81, 82, 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, 53, 54, 57, 58, 85, 86, 89, 90, 55, 56, 59, 60, 87, 88, 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, 63, 64, 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, 67, 68, 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, 67, 68, 71, 72, 99, 100, 103, 104, 69, 70, 73, 74, 101, 102, 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, 73, 74, 77, 78, 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, 81, 82, 85, 86, 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, 119, 120, 85, 86, 89, 90, 117, 118, 121, 122, 87, 88, 91, 92, 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, 91, 92, 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, 0, 0, 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, 0, 0, 99, 100, 103, 104, 0, 0, 0, 0, 101, 102, 105, 106, 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, 0, 0, 105, 106, 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, 109, 110, 0, 0, 0, 0, 0, 0}); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); //output->printShapeInfo("Output shape"); // output->printIndexedBuffer("Output"); // exp.printBuffer("Expect"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); x.linspace(1); auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { 1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., 12., 5., 6., 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., 15., 16., 9., 10., 11., 12., 15., 16., 17., 18., 11., 12., 0., 0., 17., 18., 0., 0., 13., 14., 15., 16., 0., 0., 0., 0., 15., 16., 17., 18., 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0. }); // ---------------------------------------------------------------- nd4j::ops::extract_image_patches op; auto result = op.execute({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); // output->printShapeInfo("Output shape"); // output->printBuffer("Output"); // exp.printBuffer("Expect"); // for (Nd4jLong e = 0; e < exp.lengthOf(); e++) // if (exp.e(e) != output->e(e)) // printf("%lld ", e); // printf("\n"); //result->at(1)->printBuffer("OUtput2"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_1) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {6}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); result->at(0)->printIndexedBuffer("z"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_2) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {-8}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_3) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {-40}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); // 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, // 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 // 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {38}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output 4"); //exp.printIndexedBuffer("Expect 4"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_4_inplace) { auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 }); auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12 }); // 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, // 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32 // 21.41, 21.42, 22.11, 22.12 // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x}, {y}, {}, {38}, {}, true, nd4j::DataType::DOUBLE); ASSERT_EQ(result, Status::OK()); //x.printIndexedBuffer("Output 4 inplace"); //exp.printIndexedBuffer("Expect 4 inplace"); ASSERT_TRUE(exp.equalsTo(&x)); // delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_5) { auto x = NDArrayFactory::create('c', {3, 4}, { 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. }); auto exp = NDArrayFactory::create('c', {3, 4}, { 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. // 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {2, 1}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_6) { auto x = NDArrayFactory::create('c', {2, 3, 2}, { 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. }); auto exp = NDArrayFactory::create('c', {2, 3, 2}, { 1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10. }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {1, 2}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_7) { auto x = NDArrayFactory::create('c', {2, 3, 2}, { 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. }); auto exp = NDArrayFactory::create('c', {2, 3, 2}, { 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {1, 2, 1, 0}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); //result->at(0)->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_8) { auto x = NDArrayFactory::create('c', {2, 3, 2}, { 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. }); auto exp = NDArrayFactory::create('c', {2, 3, 2}, { 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. }); // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, true, nd4j::DataType::DOUBLE); ASSERT_EQ(result, Status::OK()); //x.printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(&x)); // delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_9) { auto x = NDArrayFactory::create('c', {2, 3, 3}, { 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17. }); auto exp = NDArrayFactory::create('c', {2, 3, 3}, { 6., 7., 8., 0., 1., 2., 3., 4., 5., 15., 16., 17., 9., 10., 11., 12., 13., 14. }); // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, true, nd4j::DataType::DOUBLE); ASSERT_EQ(result, Status::OK()); x.printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(&x)); // delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_10) { auto x = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(out)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_11) { auto x = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); auto shift = NDArrayFactory::create({1,2}); auto axis = NDArrayFactory::create({0, 1}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 }); // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(out)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_12) { auto x = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); auto shift = NDArrayFactory::create({1,1,1}); auto axis = NDArrayFactory::create({0, 1, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 }); // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); out->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(out)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_13) { auto x = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); auto shift = NDArrayFactory::create(3); auto axis = NDArrayFactory::create(2); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 }); // ---------------------------------------------------------------- nd4j::ops::roll op; NDArray* y = nullptr; auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(out)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestRoll_14) { auto x = NDArrayFactory::create('c', {2, 3, 4}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }); auto shift = NDArrayFactory::create({1,1,1}); auto axis = NDArrayFactory::create({0, 1, 2}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 }); // ---------------------------------------------------------------- nd4j::ops::roll op; auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(result->status(), Status::OK()); auto out = result->at(0); // out->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); ASSERT_TRUE(exp.equalsTo(out)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create(50.); nd4j::ops::percentile op; auto result = op.execute({&input}, {50.}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test2) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 1}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test3) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,1}, {10.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 0, 1}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test4) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 1, 1}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test5) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 0, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test6) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,4}, {16., 14., 15., 13.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 1, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test7) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 1}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test8) { const int dim0=5, dim1=5, dim2=4; auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 0}, {0,1}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test9) { const int dim0=100; auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create(11.); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 0}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test10) { const int dim0=100; auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., 82., 90., 91., 89., 92., 34., 35., 33., 36.}); auto expected = NDArrayFactory::create('c', {1}, {11.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 1}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test11) { const int dim0=1; auto input = NDArrayFactory::create('c', {dim0}, {100.}); auto expected = NDArrayFactory::create('c', {1}, {100.}); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 1}, {0}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test12) { const int dim0=1; auto input = NDArrayFactory::create('c', {dim0}, {100.}); auto expected = NDArrayFactory::create(100.); nd4j::ops::percentile op; //q, interpolation, keepDims auto result = op.execute({&input}, {10, 2, 0}, {}); auto output = result->at(0); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, transpose_test3) { auto input = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); auto exp = NDArrayFactory::create('c', {3, 5}, {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); nd4j::ops::transpose op; auto result = op.execute({&input}, {}, {}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test1) { auto input = NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); nd4j::ops::rationaltanh op; auto result = op.execute({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rationaltanh"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test2) { auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); nd4j::ops::rationaltanh op; auto result = op.execute({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rationaltanh"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rationaltanh_test3) { auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); nd4j::ops::rationaltanh_bp op; auto result = op.execute({&input, &eps}, {}, {}); auto output = result->at(0); // output->printBuffer("Output rationaltanh BP"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); nd4j::ops::rectifiedtanh op; auto result = op.execute({&input}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Output rectifiedtanh"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); nd4j::ops::rectifiedtanh_bp op; auto result = op.execute({&input, &eps}, {}, {}); auto output = result->at(0); // output->printBuffer("Output rectifiedtanh BP"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, RealDiv_1) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2, 1, 4, 2}); nd4j::ops::realdiv op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput RealDiv"); ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2, 5}); NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14, -5}); NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); nd4j::ops::realdiv_bp op; auto result = op.execute({&x, &y, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z0 = result->at(0); auto z1 = result->at(1); // z0->printShapeInfo("OUtput RealDiv BP0 shape"); // z1->printShapeInfo("OUtput RealDiv BP1 shape"); // z0->printIndexedBuffer("OUtput RealDiv BP0"); // z1->printIndexedBuffer("OUtput RealDiv BP1"); // ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e0.equalsTo(z0)); ASSERT_TRUE(e1.equalsTo(z1)); delete result; } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_1) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); // NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); NDArray e = NDArrayFactory::create({1, 2, 1}); nd4j::ops::shapes_of op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput RealDiv"); // ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ShapesOf_2) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); NDArray e0 = NDArrayFactory::create({1, 2, 1}); NDArray e1 = NDArrayFactory::create({1, 2}); nd4j::ops::shapes_of op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z0 = result->at(0); auto z1 = result->at(1); // z0->printIndexedBuffer("OUtput shapes2"); // z1->printIndexedBuffer("OUtput shapes2"); // ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e0.equalsTo(z0)); ASSERT_TRUE(e1.equalsTo(z1)); delete result; } TEST_F(DeclarableOpsTests7, Size_1) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray e = NDArrayFactory::create(2); nd4j::ops::size op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput SIZE"); /// ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } TEST_F(DeclarableOpsTests7, Size_2) { NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray e = NDArrayFactory::create(10); nd4j::ops::size op; auto result = op.execute({&y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput SIZE"); /// ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } TEST_F(DeclarableOpsTests7, Softplus_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); nd4j::ops::softplus op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput Softplus"); /// ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } TEST_F(DeclarableOpsTests7, Softplus_BP_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); // NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); nd4j::ops::softplus ffOP; nd4j::ops::softplus_bp bpOp; const OpArgsHolder argsHolderFF({&x}, {}, {}); const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); ASSERT_TRUE(gradOK); // // auto z = result->at(0); // z->printIndexedBuffer("OUtput Softplus"); ///// ASSERT_TRUE(e.isSameShape(z)); // ASSERT_TRUE(e.equalsTo(*z)); // // delete result; } TEST_F(DeclarableOpsTests7, Softsign_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); nd4j::ops::softsign op; auto result = op.execute({&x}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("OUtput Softsign"); /// ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.equalsTo(*z)); delete result; } TEST_F(DeclarableOpsTests7, Softsign_BP_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); // NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); nd4j::ops::softsign ffOP; nd4j::ops::softsign_bp bpOp; const OpArgsHolder argsHolderFF({&x}, {}, {}); const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); ASSERT_TRUE(gradOK); } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test2) { auto x = NDArrayFactory::create('c', {1,2}, {2, 2}); auto v = NDArrayFactory::create(42.); auto exp = NDArrayFactory::create('c', {2, 2},{42.f, 42.f, 42.f, 42.f}); nd4j::ops::fill op; auto result = op.execute({&x, &v}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test3) { auto x = NDArrayFactory::create('c', {2}, {2, 2}); auto v = NDArrayFactory::create(42.); auto exp = NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); nd4j::ops::fill op; auto result = op.execute({&x, &v}, {}, {}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test1) { auto x = NDArrayFactory::create('c', {2}, {2, 2}); auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); nd4j::ops::toggle_bits op; auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT32); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, ToggleBits_test2) { auto x = NDArrayFactory::create('c', {2}, {2, 2}); auto y = NDArrayFactory::create('c', {2}, {1, 1}); auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); nd4j::ops::toggle_bits op; auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); auto output = result->at(0); auto z = result->at(1); ASSERT_EQ(ND4J_STATUS_OK, result->status()); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp0.isSameShape(output)); ASSERT_TRUE(exp0.equalsTo(output)); ASSERT_TRUE(exp1.isSameShape(z)); ASSERT_TRUE(exp1.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray y = NDArrayFactory::create('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2}); NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); nd4j::ops::truncatediv op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp.isSameShape(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Truncatediv_test2) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray y = NDArrayFactory::create('c', {1, 2}, {2,2}); NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); nd4j::ops::truncatediv op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp.isSameShape(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); nd4j::ops::to_int32 op32; nd4j::ops::to_int64 op64; auto result32 = op32.execute({&x}, {}, {}); auto result64 = op64.execute({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); auto out1 = result32->at(0); // out1->printIndexedBuffer("OUT_I"); auto out2 = result64->at(0); // out2->printIndexedBuffer("OUT_L"); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(expI.equalsTo(out1)); ASSERT_TRUE(expL.equalsTo(out2)); delete result32; delete result64; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test2) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray expH = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); nd4j::ops::to_float32 op32; nd4j::ops::to_float16 op16; auto result32 = op32.execute({&x}, {}, {}); auto result16 = op16.execute({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result16->status()); auto out1 = result32->at(0); // out1->printIndexedBuffer("OUT_F"); auto out2 = result16->at(0); // out2->printIndexedBuffer("OUT_H"); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(expF.equalsTo(out1)); ASSERT_TRUE(expH.equalsTo(out2)); delete result32; delete result16; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test3) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); nd4j::ops::to_uint32 op32; nd4j::ops::to_uint64 op64; auto result32 = op32.execute({&x}, {}, {}); auto result64 = op64.execute({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); auto out1 = result32->at(0); // out1->printIndexedBuffer("OUT_U32"); auto out2 = result64->at(0); // out2->printIndexedBuffer("OUT_U64"); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp32.equalsTo(out1)); ASSERT_TRUE(exp64.equalsTo(out2)); delete result32; delete result64; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TypesConversion_test4) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); nd4j::ops::to_float32 op32; nd4j::ops::to_double op64; auto result32 = op32.execute({&x}, {}, {}); auto result64 = op64.execute({&x}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result32->status()); ASSERT_EQ(ND4J_STATUS_OK, result64->status()); auto out1 = result32->at(0); out1->printIndexedBuffer("OUT_F"); auto out2 = result64->at(0); out2->printIndexedBuffer("OUT_D"); // output->printIndexedBuffer("Toggled"); ASSERT_TRUE(exp32.equalsTo(out1)); ASSERT_TRUE(exp64.equalsTo(out2)); delete result32; delete result64; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test1) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); auto exp = NDArrayFactory::create('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test2) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); auto exp = NDArrayFactory::create('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test3) { auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {1,2}, {2, 2}); auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test4) { auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test5) { auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); output->printBuffer("Output"); exp.printBuffer("Expected"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test6) { auto input = NDArrayFactory::create(1.); auto paddings = NDArrayFactory::create('c', {1,2,1,1}, {1, 1}); auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test7) { auto input = NDArrayFactory::create(1.); auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test8) { auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); auto exp = NDArrayFactory::create('c', {3,9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test9) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); auto exp = NDArrayFactory::create('c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test10) { auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test11) { auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test12) { auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); auto paddings = NDArrayFactory::create('c', {2,1}, {0, 0}); auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test13) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); auto exp = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test14) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 1}); auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test15) { auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {1}); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test16) { auto input = NDArrayFactory::create('c', {4,3,2}); auto paddings = NDArrayFactory::create('c', {3,2}, {3,3,2,2,1,1}); auto exp = NDArrayFactory::create('c', {10,7,4}, {24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., 24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); input.linspace(1.); nd4j::ops::mirror_pad op; auto result = op.execute({&input, &paddings}, {}, {0}); ASSERT_EQ(result->status(), Status::OK()); auto output = result->at(0); //output->printBuffer("VVV"); //exp.printBuffer("EXP"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_1) { auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); auto exp = NDArrayFactory::create(120.f); //************************************// nd4j::ops::reduce_sum op; auto result = op.execute({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); //z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_2) { auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); //************************************// nd4j::ops::reduce_sum op; auto result = op.execute({&input}, {}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_1) { auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); auto exp = NDArrayFactory::create(1307674368000.f); //************************************// nd4j::ops::reduce_prod op; auto result = op.execute({&input}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); //z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_2) { auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); //************************************// nd4j::ops::reduce_prod op; auto result = op.execute({&input}, {}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_01) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_02) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_3) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_4) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_5) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create(300.f); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_6) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create(300.f); x.linspace(1); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_7) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_sum op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_01) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_02) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_3) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_4) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_5) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create(479001600.f); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_6) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create(479001600.f); x.linspace(1); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { auto x = NDArrayFactory::create('c', {2,3,2}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_prod op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, Test_Matmul_Once_Again) { auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); nd4j::ops::matmul op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(exp, *result->at(0)); delete result; } TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0}); auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0}); nd4j::ops::pnormpool2d op; auto result = op.execute({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(exp, *result->at(0)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(1.f); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(1.f); x.linspace(1); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_min op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); // output->printShapeInfo("Output shape"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_2) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(24.f); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(24.f); x.linspace(1); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_max op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_2) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(300.f); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(300.f); x.linspace(1); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_norm1 op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_2) { auto x = NDArrayFactory::create('c', {2,3,4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {1.}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {1.}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(70.f); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(70.f); x.linspace(1); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); x.linspace(1); // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); nd4j::ops::reduce_norm2 op; auto result = op.execute({&x}, {1.}, {0,1,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(24.f); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(24.f); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); x.linspace(1); nd4j::ops::reduce_norm_max op; auto result = op.execute({&x}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(4900.f); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create(4900.f); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); x.linspace(1); nd4j::ops::reduce_sqnorm op; auto result = op.execute({&x}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_1) { auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto eps = NDArrayFactory::create(0.5f); auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); //************************************// nd4j::ops::reduce_sum_bp op; auto result = op.execute({&input, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_2) { auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); //************************************// nd4j::ops::reduce_sum_bp op; auto result = op.execute({&input, &eps}, {1.f}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_3) { auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); //************************************// nd4j::ops::reduce_sum_bp op; auto result = op.execute({&input, &eps}, {}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_4) { auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); //************************************// nd4j::ops::reduce_sum_bp op; auto result = op.execute({&input, &eps}, {1.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_1) { auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); auto eps = NDArrayFactory::create(1307674368000.f); //************************************// // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); nd4j::ops::reduce_prod_bp op; auto result = op.execute({&input, &eps}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); auto eps = NDArrayFactory::create(0.5f); //************************************// // auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 4}); nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; auto res = op_exp.execute({&input}, {}, {}); auto result = op.execute({&input, &eps}, {}, {}); exp.assign(res->at(0)->e(0)); exp /= input; exp *= eps.e(0); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); //z->printIndexedBuffer("Result is "); //exp.printIndexedBuffer("Expected"); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_3) { auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); nd4j::ops::reduce_prod_bp op; //nd4j::ops::reduce_prod op_exp; auto result = op.execute({&input, &eps}, {1.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // exp.printIndexedBuffer("Expected"); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_03) { int ax = 0; auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); auto axis = NDArrayFactory::create('c', {1}, {ax}); nd4j::ops::reduce_prod_bp op; //nd4j::ops::reduce_prod op_exp; auto result = op.execute({&input, &eps, &axis}, {}, {}, {true}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // exp.printIndexedBuffer("Expected"); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_4) { auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; // auto res = op_exp.execute({&input}, {}, {}); auto result = op.execute({&input, &eps}, {0.f}, {0}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // exp.printIndexedBuffer("Expected"); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; // delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_5) { auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); //************************************// auto exp = NDArrayFactory::create('c', {3, 4}, {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); nd4j::ops::reduce_prod_bp op; nd4j::ops::reduce_prod op_exp; // auto res = op_exp.execute({&input}, {}, {}); auto result = op.execute({&input, &eps}, {0.f}, {1}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); // z->printIndexedBuffer("Result is "); // exp.printIndexedBuffer("Expected"); // z->printShapeInfo(); ASSERT_TRUE(exp.equalsTo(z)); delete result; // delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(0, eps.e(0)); exp.p(1, eps.e(1)); exp.p(2, eps.e(2)); exp.p(3, eps.e(3)); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(0, eps.e(0)); exp.p(1, eps.e(1)); exp.p(2, eps.e(2)); exp.p(3, eps.e(3)); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(0, eps.e(0)); exp.p(1, eps.e(1)); exp.p(2, eps.e(2)); exp.p(3, eps.e(3)); auto axes = NDArrayFactory::create({0,1}); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_3) { auto x = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); auto exp = NDArrayFactory::create('c', {3, 4}); x.linspace(1); x.p(2,2, -1.f); exp.p(2,2, 0.5f); //x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_4) { auto x = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create(0.5f); auto exp = NDArrayFactory::create('c', {3, 4}); x.linspace(1); x.p(2,2, -1.f); exp.p(2,2, 0.5f); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_5) { auto x = NDArrayFactory::create('c', {4, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {4, 4}); x.linspace(1); x.p(0,0, -1.f); x.p(1,1, -2.f); x.p(2,2, -3.f); x.p(3,3, -4.f); exp.p(0,0, 1.f); exp.p(1,1, 2.f); exp.p(2,2, 3.f); exp.p(3,3, 4.f); // exp(2,2) = 0.5f; // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_6) { auto x = NDArrayFactory::create('c', {4, 4}); auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {4, 4}); x.linspace(1); x.p(0,0, -1.f); x.p(1,1, -2.f); x.p(2,2, -3.f); x.p(3,3, -4.f); exp.p(0,0, 1.f); exp.p(1,1, 2.f); exp.p(2,2, 3.f); exp.p(3,3, 4.f); // exp(2,2) = 0.5f; // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_min_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(20, eps.e(0)); exp.p(21, eps.e(1)); exp.p(22, eps.e(2)); exp.p(23, eps.e(3)); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; auto result = op.execute({&x, &eps}, {}, {0, 1}); auto output = result->at(0); exp.printIndexedBuffer("E"); output->printIndexedBuffer("O"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(20, eps.e(0)); exp.p(21, eps.e(1)); exp.p(22, eps.e(2)); exp.p(23, eps.e(3)); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0, 1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); exp.p(20, eps.e(0)); exp.p(21, eps.e(1)); exp.p(22, eps.e(2)); exp.p(23, eps.e(3)); auto axes = NDArrayFactory::create({0, 1}); x.linspace(1); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_3) { auto x = NDArrayFactory::create('c', {4, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {4, 4}); x.linspace(1); x.p(0,0, 21.f); x.p(1,1, 22.f); x.p(2,2, 23.f); x.p(3,3, 24.f); exp.p(0,0, 1.f); exp.p(1,1, 2.f); exp.p(2,2, 3.f); exp.p(3,3, 4.f); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; auto result = op.execute({&x, &eps}, {}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_4) { auto x = NDArrayFactory::create('c', {4, 4}); auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {4, 4}); x.linspace(1); x.p(0,0, 21.f); x.p(1,1, 22.f); x.p(2,2, 23.f); x.p(3,3, 24.f); exp.p(0,0, 1.f); exp.p(1,1, 2.f); exp.p(2,2, 3.f); exp.p(3,3, 4.f); // x.printIndexedBuffer("Input is"); // exp.printIndexedBuffer("Expected "); nd4j::ops::reduce_max_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create(5.f); x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.p(12, -2.f); x.p(20, -3.f); exp.assign(5.f); exp.p(12, -exp.e(12)); exp.p(20, -exp.e(20)); nd4j::ops::reduce_norm1_bp op; auto result = op.execute({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); nd4j::ops::reduce_norm1_bp op; auto result = op.execute({&x, &eps}, {}, {0,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer("Result is"); // exp.printIndexedBuffer("Expect is"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); auto axes = NDArrayFactory::create({0,1}); nd4j::ops::reduce_norm1_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {false}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); output->printIndexedBuffer("Result is"); exp.printIndexedBuffer("Expect is"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); x.linspace(1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); nd4j::ops::reduce_norm1_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); x.linspace(1); nd4j::ops::reduce_norm2_bp op; auto result = op.execute({&x, &eps}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); x.linspace(1); nd4j::ops::reduce_norm2_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); auto axes = NDArrayFactory::create({0, 1}); x.linspace(1); nd4j::ops::reduce_norm2_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); x.linspace(1); nd4j::ops::reduce_norm2_bp op; auto result = op.execute({&x, &eps}, {}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); x.linspace(1); nd4j::ops::reduce_norm2_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); x.linspace(1); nd4j::ops::reduce_sqnorm_bp op; auto result = op.execute({&x, &eps}, {}, {0,1}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_01) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, 10.f, 24.f, 42.f, 64.f, 18.f, 40.f, 66.f, 96.f, 26.f, 56.f, 90.f, 128.f, 34.f, 72.f, 114.f, 160.f, 42.f, 88.f, 138.f, 192.f}); auto axes = NDArrayFactory::create({0, 1}); x.linspace(1); nd4j::ops::reduce_sqnorm_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {false}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(20, 1.f); exp.p(21, 2.f); exp.p(22, 3.f); exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(20, 1.f); exp.p(21, 2.f); exp.p(22, 3.f); exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0,1}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); auto axes = NDArrayFactory::create({0,1}); x.linspace(1); exp.p(20, 1.f); exp.p(21, 2.f); exp.p(22, 3.f); exp.p(23, 4.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps, &axes}, {}, {}, {true}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_3) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(15, 1.f); exp.p(19, 2.f); exp.p(23, 3.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_4) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(15, 1.f); exp.p(19, 2.f); exp.p(23, 3.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {1.f}, {0,2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_5) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create(1.f); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_6) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create(1.f); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {}, {0, 1, 2}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_7) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); exp.p(23, 1.f); nd4j::ops::reduce_norm_max_bp op; auto result = op.execute({&x, &eps}, {1.f}, {}); auto output = result->at(0); // output->printIndexedBuffer("Result is"); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_1) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto y = NDArrayFactory::create('c', {2, 3, 4}); NDArray* z; // = NDArrayFactory::create('c', {4}); auto eps = NDArrayFactory::create(1.f); // auto exp = NDArrayFactory::create('c', {2, 3, 4}); x.linspace(1); y.linspace(2); nd4j::ops::reduce_dot_bp op; auto result = op.execute({&x, &y, &eps}, {}, {}); auto output = result->at(0); auto outputX = result->at(1); //tput->printIndexedBuffer("Result is"); // ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(x.equalsTo(outputX)); ASSERT_TRUE(y.equalsTo(output)); delete result; // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_2) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto y = NDArrayFactory::create('c', {2, 3, 4}); // auto z; // = NDArrayFactory::create('c', {4}); auto eps = NDArrayFactory::create('c', {2, 4}); auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f }); auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f }); x.assign(1.f); eps.linspace(1); y.assign(2.f); nd4j::ops::reduce_dot_bp op; auto result = op.execute({&x, &y, &eps}, {}, {1}); ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->size(), 2); auto outputX = result->at(0); auto outputY = result->at(1); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(expX.equalsTo(outputX)); ASSERT_TRUE(expY.equalsTo(outputY)); delete result; // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_02) { auto x = NDArrayFactory::create('c', {2, 3, 4}); auto y = NDArrayFactory::create('c', {2, 3, 4}); // auto z; // = NDArrayFactory::create('c', {4}); auto eps = NDArrayFactory::create('c', {2, 4}); auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f }); auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f }); auto axis = NDArrayFactory::create('c', {1}, {1}); x.assign(1.f); eps.linspace(1); y.assign(2.f); nd4j::ops::reduce_dot_bp op; auto result = op.execute({&x, &y, &eps, &axis}, {}, {}, {false}); ASSERT_EQ(result->status(), ND4J_STATUS_OK); ASSERT_EQ(result->size(), 2); auto outputX = result->at(0); auto outputY = result->at(1); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(expX.equalsTo(outputX)); ASSERT_TRUE(expY.equalsTo(outputY)); delete result; // delete z; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) { auto x = NDArrayFactory::create('c', {3, 4}); auto y = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create('c', {3}); auto expX = NDArrayFactory::create('c', {3, 4}, {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); auto expY = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); x.linspace(1); eps.linspace(1); y.assign(2.f); nd4j::ops::reduce_dot_bp op; auto result = op.execute({&x,&y, &eps}, {}, {1}); auto outputX = result->at(0); auto outputY = result->at(1); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(expX.equalsTo(outputX)); ASSERT_TRUE(expY.equalsTo(outputY)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_1) { auto x = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create('c', {3, 4}); auto exp = NDArrayFactory::create('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); x.linspace(1); eps.assign(1.f); nd4j::ops::cumsum_bp op; auto result = op.execute({&x, &eps}, {}, {0,0}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_2) { auto x = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create('c', {3, 4}); auto exp = NDArrayFactory::create('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); x.linspace(1); eps.assign(1.f); nd4j::ops::cumsum_bp op; auto result = op.execute({&x, &eps}, {}, {1,0}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); delete result; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, cumsum_bp_3) { auto x = NDArrayFactory::create('c', {3, 4}); auto eps = NDArrayFactory::create('c', {3, 4}); auto exp = NDArrayFactory::create('c', {3, 4}); x.linspace(1); exp.linspace(0); eps.assign(1.f); nd4j::ops::cumsum_bp op; auto result = op.execute({&x, &eps}, {}, {1,1}); auto output = result->at(0); ASSERT_EQ(ND4J_STATUS_OK, result->status()); ASSERT_TRUE(exp.equalsTo(output)); delete result; }